From 5f8ccd5789137363e035d1dfb9a05d3b9bf3ce6b Mon Sep 17 00:00:00 2001 From: Ji Yan Date: Thu, 9 Mar 2017 21:30:11 -0800 Subject: [PATCH 001/512] respect both gpu and maxgpu --- .../mesos/MesosCoarseGrainedSchedulerBackend.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index f555072c3842..ecd499100311 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -62,6 +62,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + val gpuCores = conf.getInt("spark.mesos.gpus", 0) private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") @@ -401,9 +402,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( launchTasks = true val taskId = newMesosTaskId() val offerCPUs = getResource(resources, "cpus").toInt - val taskGPUs = Math.min( - Math.max(0, maxGpus - totalGpusAcquired), getResource(resources, "gpus").toInt) - + var taskGPUs = Math.min(Math.max(0, maxGpus - totalGpusAcquired), + getResource(resources, "gpus").toInt) + if (gpuCores > 0) { + taskGPUs = gpuCores + } val taskCPUs = executorCores(offerCPUs) val taskMemory = executorMemory(sc) @@ -462,6 +465,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { val offerMem = getResource(resources, "mem") val offerCPUs = getResource(resources, "cpus").toInt + val offerGPUs = getResource(resources, "gpus").toInt val cpus = executorCores(offerCPUs) val mem = executorMemory(sc) val ports = getRangeResource(resources, "ports") @@ -471,6 +475,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpus <= offerCPUs && cpus + totalCoresAcquired <= maxCores && mem <= offerMem && + gpuCores <= offerGPUs && numExecutors() < executorLimit && slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES && meetsPortRequirements From 5949e6c4477fd3cb07a6962dbee48b4416ea65dd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 9 Mar 2017 22:58:52 -0800 Subject: [PATCH 002/512] [SPARK-19008][SQL] Improve performance of Dataset.map by eliminating boxing/unboxing ## What changes were proposed in this pull request? This PR improve performance of Dataset.map() for primitive types by removing boxing/unbox operations. This is based on [the discussion](https://github.com/apache/spark/pull/16391#discussion_r93788919) with cloud-fan. Current Catalyst generates a method call to a `apply()` method of an anonymous function written in Scala. The types of an argument and return value are `java.lang.Object`. As a result, each method call for a primitive value involves a pair of unboxing and boxing for calling this `apply()` method and a pair of boxing and unboxing for returning from this `apply()` method. This PR directly calls a specialized version of a `apply()` method without boxing and unboxing. 
For example, if types of an arguments ant return value is `int`, this PR generates a method call to `apply$mcII$sp`. This PR supports any combination of `Int`, `Long`, `Float`, and `Double`. The following is a benchmark result using [this program](https://github.com/apache/spark/pull/16391/files) with 4.7x. Here is a Dataset part of this program. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ RDD 1923 / 1952 52.0 19.2 1.0X DataFrame 526 / 548 190.2 5.3 3.7X Dataset 3094 / 3154 32.3 30.9 0.6X ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ RDD 1883 / 1892 53.1 18.8 1.0X DataFrame 502 / 642 199.1 5.0 3.7X Dataset 657 / 784 152.2 6.6 2.9X ``` ```java def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ val rdd = spark.sparkContext.range(0, numRows) val ds = spark.range(0, numRows) val func = (l: Long) => l + 1 val benchmark = new Benchmark("back-to-back map", numRows) ... benchmark.addCase("Dataset") { iter => var res = ds.as[Long] var i = 0 while (i < numChains) { res = res.map(func) i += 1 } res.queryExecution.toRdd.foreach(_ => Unit) } benchmark } ``` A motivating example ```java Seq(1, 2, 3).toDS.map(i => i * 7).show ``` Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow deserializetoobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 012 */ private int mapelements_argValue; /* 013 */ private UnsafeRow mapelements_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 016 */ private UnsafeRow serializefromobject_result; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 019 */ /* 020 */ public GeneratedIterator(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ inputadapter_input = inputs[0]; /* 028 */ deserializetoobject_result = new UnsafeRow(1); /* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0); /* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); 
/* 031 */ /* 032 */ mapelements_result = new UnsafeRow(1); /* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0); /* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 035 */ serializefromobject_result = new UnsafeRow(1); /* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 038 */ /* 039 */ } /* 040 */ /* 041 */ protected void processNext() throws java.io.IOException { /* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ int inputadapter_value = inputadapter_row.getInt(0); /* 045 */ /* 046 */ boolean mapelements_isNull = true; /* 047 */ int mapelements_value = -1; /* 048 */ if (!false) { /* 049 */ mapelements_argValue = inputadapter_value; /* 050 */ /* 051 */ mapelements_isNull = false; /* 052 */ if (!mapelements_isNull) { /* 053 */ Object mapelements_funcResult = null; /* 054 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 055 */ if (mapelements_funcResult == null) { /* 056 */ mapelements_isNull = true; /* 057 */ } else { /* 058 */ mapelements_value = (Integer) mapelements_funcResult; /* 059 */ } /* 060 */ /* 061 */ } /* 062 */ /* 063 */ } /* 064 */ /* 065 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 066 */ /* 067 */ if (mapelements_isNull) { /* 068 */ serializefromobject_rowWriter.setNullAt(0); /* 069 */ } else { /* 070 */ serializefromobject_rowWriter.write(0, mapelements_value); /* 071 */ } /* 072 */ append(serializefromobject_result); /* 073 */ if (shouldStop()) return; /* 074 */ } /* 075 */ } /* 076 */ } ``` Generated code with this PR (lines 48-56 are changed) ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow deserializetoobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 012 */ private int mapelements_argValue; /* 013 */ private UnsafeRow mapelements_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 016 */ private UnsafeRow serializefromobject_result; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 019 */ /* 020 */ public GeneratedIterator(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ inputadapter_input = inputs[0]; /* 028 */ deserializetoobject_result = new UnsafeRow(1); /* 029 
*/ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0); /* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 031 */ /* 032 */ mapelements_result = new UnsafeRow(1); /* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0); /* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 035 */ serializefromobject_result = new UnsafeRow(1); /* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 038 */ /* 039 */ } /* 040 */ /* 041 */ protected void processNext() throws java.io.IOException { /* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ int inputadapter_value = inputadapter_row.getInt(0); /* 045 */ /* 046 */ boolean mapelements_isNull = true; /* 047 */ int mapelements_value = -1; /* 048 */ if (!false) { /* 049 */ mapelements_argValue = inputadapter_value; /* 050 */ /* 051 */ mapelements_isNull = false; /* 052 */ if (!mapelements_isNull) { /* 053 */ mapelements_value = ((scala.Function1) references[0]).apply$mcII$sp(mapelements_argValue); /* 054 */ } /* 055 */ /* 056 */ } /* 057 */ /* 058 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 059 */ /* 060 */ if (mapelements_isNull) { /* 061 */ serializefromobject_rowWriter.setNullAt(0); /* 062 */ } else { /* 063 */ serializefromobject_rowWriter.write(0, mapelements_value); /* 064 */ } /* 065 */ append(serializefromobject_result); /* 066 */ if (shouldStop()) return; /* 067 */ } /* 068 */ } /* 069 */ } ``` Java bytecode for methods for `i => i * 7` ```java $ javap -c Test\$\$anonfun\$5\$\$anonfun\$apply\$mcV\$sp\$1.class Compiled from "Test.scala" public final class org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1 extends scala.runtime.AbstractFunction1$mcII$sp implements scala.Serializable { public static final long serialVersionUID; public final int apply(int); Code: 0: aload_0 1: iload_1 2: invokevirtual #18 // Method apply$mcII$sp:(I)I 5: ireturn public int apply$mcII$sp(int); Code: 0: iload_1 1: bipush 7 3: imul 4: ireturn public final java.lang.Object apply(java.lang.Object); Code: 0: aload_0 1: aload_1 2: invokestatic #29 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I 5: invokevirtual #31 // Method apply:(I)I 8: invokestatic #35 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer; 11: areturn public org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1(org.apache.spark.sql.Test$$anonfun$5); Code: 0: aload_0 1: invokespecial #42 // Method scala/runtime/AbstractFunction1$mcII$sp."":()V 4: return } ``` ## How was this patch tested? Added new test suites to `DatasetPrimitiveSuite`. Author: Kazuaki Ishizaki Closes #17172 from kiszk/SPARK-19008. 
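As an illustrative aside (this sketch is not part of the patch, and the object name is made up), the specialized entry point that the generated code now calls can be observed on any `Int => Int` function object: `apply$mcII$sp(int)` works on primitives, while the generic `apply(Object)` entry point boxes both the argument and the result.

```scala
// Illustrative sketch only: shows that an Int => Int function object carries a
// specialized apply$mcII$sp(int) entry point, which is what the generated
// Catalyst code now invokes directly instead of the boxing apply(Object).
object SpecializationSketch {
  def main(args: Array[String]): Unit = {
    val f: Int => Int = i => i * 7

    // Generic entry point: the Int argument is boxed on the way in and the
    // Int result is boxed on the way out.
    val viaGenericApply: Any = f.asInstanceOf[Any => Any].apply(3)

    // Specialized entry point, looked up reflectively here just to show that
    // the method exists on the compiled function class; the generated Java
    // code calls it statically and therefore avoids boxing entirely.
    val m = f.getClass.getMethod("apply$mcII$sp", classOf[Int])
    m.setAccessible(true) // the concrete function class may not be public
    val viaSpecializedApply = m.invoke(f, Int.box(3))

    println(s"$viaGenericApply $viaSpecializedApply") // 21 21
  }
}
```

This is the same split the javap output above shows: `apply(Object)` goes through `BoxesRunTime`, while `apply$mcII$sp(int)` does not, so calling the specialized variant removes both box/unbox pairs per row.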
--- .../sql/catalyst/plans/logical/object.scala | 38 +++++- .../apache/spark/sql/execution/objects.scala | 6 +- .../apache/spark/sql/DatasetBenchmark.scala | 122 +++++++++++++++++- .../spark/sql/DatasetPrimitiveSuite.scala | 51 ++++++++ 4 files changed, 208 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 617239f56cdd..7f4462e58360 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object CatalystSerde { def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { @@ -211,13 +212,48 @@ case class TypedFilter( def typedCondition(input: Expression): Expression = { val (funcClass, methodName) = func match { case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call" - case _ => classOf[Any => Boolean] -> "apply" + case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) Invoke(funcObj, methodName, BooleanType, input :: Nil) } } +object FunctionUtils { + private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = { + dt match { + case BooleanType if isOutput => Some("Z") + case IntegerType => Some("I") + case LongType => Some("J") + case FloatType => Some("F") + case DoubleType => Some("D") + case _ => None + } + } + + def getFunctionOneName(outputDT: DataType, inputDT: DataType): (Class[_], String) = { + // load "scala.Function1" using Java API to avoid requirements of type parameters + Utils.classForName("scala.Function1") -> { + // if a pair of an argument and return types is one of specific types + // whose specialized method (apply$mc..$sp) is generated by scalac, + // Catalyst generated a direct method call to the specialized method. + // The followings are references for this specialization: + // http://www.scala-lang.org/api/2.12.0/scala/Function1.html + // https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/ + // SpecializeTypes.scala + // http://www.cakesolutions.net/teamblogs/scala-dissection-functions + // http://axel22.github.io/2013/11/03/specialization-quirks.html + val inputType = getMethodType(inputDT, false) + val outputType = getMethodType(outputDT, true) + if (inputType.isDefined && outputType.isDefined) { + s"apply$$mc${outputType.get}${inputType.get}$$sp" + } else { + "apply" + } + } + } +} + /** Factory for constructing new `AppendColumn` nodes. 
*/ object AppendColumns { def apply[T : Encoder, U : Encoder]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 199ba5ce6969..fdd1bcc94be2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState import org.apache.spark.sql.execution.streaming.KeyedStateImpl -import org.apache.spark.sql.types.{DataType, ObjectType, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -219,7 +221,7 @@ case class MapElementsExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val (funcClass, methodName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" - case _ => classOf[Any => Any] -> "apply" + case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 66d94d601605..1a0672b8876d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -31,6 +31,49 @@ object DatasetBenchmark { case class Data(l: Long, s: String) + def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ + + val rdd = spark.sparkContext.range(0, numRows) + val ds = spark.range(0, numRows) + val df = ds.toDF("l") + val func = (l: Long) => l + 1 + + val benchmark = new Benchmark("back-to-back map long", numRows) + + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.select($"l" + 1 as "l") + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = ds.as[Long] + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ @@ -72,6 +115,49 @@ object DatasetBenchmark { benchmark } + def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ + + val rdd = spark.sparkContext.range(1, numRows) + val ds = spark.range(1, numRows) + val df = ds.toDF("l") + val func = (l: Long) => l % 2L == 0L + + val benchmark = new Benchmark("back-to-back filter Long", numRows) + + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = res.filter(func) + i += 
1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.filter($"l" % 2L === 0L) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = ds.as[Long] + var i = 0 + while (i < numChains) { + res = res.filter(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ @@ -165,9 +251,22 @@ object DatasetBenchmark { val numRows = 100000000 val numChains = 10 - val benchmark = backToBackMap(spark, numRows, numChains) - val benchmark2 = backToBackFilter(spark, numRows, numChains) - val benchmark3 = aggregate(spark, numRows) + val benchmark0 = backToBackMapLong(spark, numRows, numChains) + val benchmark1 = backToBackMap(spark, numRows, numChains) + val benchmark2 = backToBackFilterLong(spark, numRows, numChains) + val benchmark3 = backToBackFilter(spark, numRows, numChains) + val benchmark4 = aggregate(spark, numRows) + + /* + OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic + Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz + back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + RDD 1883 / 1892 53.1 18.8 1.0X + DataFrame 502 / 642 199.1 5.0 3.7X + Dataset 657 / 784 152.2 6.6 2.9X + */ + benchmark0.run() /* OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 @@ -178,7 +277,18 @@ object DatasetBenchmark { DataFrame 2647 / 3116 37.8 26.5 1.3X Dataset 4781 / 5155 20.9 47.8 0.7X */ - benchmark.run() + benchmark1.run() + + /* + OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic + Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz + back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + RDD 846 / 1120 118.1 8.5 1.0X + DataFrame 270 / 329 370.9 2.7 3.1X + Dataset 545 / 789 183.5 5.4 1.6X + */ + benchmark2.run() /* OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 @@ -189,7 +299,7 @@ object DatasetBenchmark { DataFrame 59 / 72 1695.4 0.6 22.8X Dataset 2777 / 2805 36.0 27.8 0.5X */ - benchmark2.run() + benchmark3.run() /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1 @@ -201,6 +311,6 @@ object DatasetBenchmark { Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X */ - benchmark3.run() + benchmark4.run() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 6b50cb3e48c7..82b707537e45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -62,6 +62,40 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 2, 3, 4) } + test("mapPrimitive") { + val dsInt = Seq(1, 2, 3).toDS() + checkDataset(dsInt.map(_ > 1), false, true, true) + checkDataset(dsInt.map(_ + 1), 2, 3, 4) + checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsInt.map(_ + 1.23D), 2.23D, 3.23D, 
4.23D) + + val dsLong = Seq(1L, 2L, 3L).toDS() + checkDataset(dsLong.map(_ > 1), false, true, true) + checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsFloat = Seq(1F, 2F, 3F).toDS() + checkDataset(dsFloat.map(_ > 1), false, true, true) + checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L) + checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsDouble = Seq(1D, 2D, 3D).toDS() + checkDataset(dsDouble.map(_ > 1), false, true, true) + checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsDouble.map(e => (e + 8589934592L).toLong), + 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F) + checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsBoolean = Seq(true, false).toDS() + checkDataset(dsBoolean.map(e => !e), false, true) + } + test("filter") { val ds = Seq(1, 2, 3, 4).toDS() checkDataset( @@ -69,6 +103,23 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 2, 4) } + test("filterPrimitive") { + val dsInt = Seq(1, 2, 3).toDS() + checkDataset(dsInt.filter(_ > 1), 2, 3) + + val dsLong = Seq(1L, 2L, 3L).toDS() + checkDataset(dsLong.filter(_ > 1), 2L, 3L) + + val dsFloat = Seq(1F, 2F, 3F).toDS() + checkDataset(dsFloat.filter(_ > 1), 2F, 3F) + + val dsDouble = Seq(1D, 2D, 3D).toDS() + checkDataset(dsDouble.filter(_ > 1), 2D, 3D) + + val dsBoolean = Seq(true, false).toDS() + checkDataset(dsBoolean.filter(e => !e), false) + } + test("foreach") { val ds = Seq(1, 2, 3).toDS() val acc = sparkContext.longAccumulator From 501b7111997bc74754663348967104181b43319b Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Thu, 9 Mar 2017 23:02:13 -0800 Subject: [PATCH 003/512] [SPARK-19891][SS] Await Batch Lock notified on stream execution exit ## What changes were proposed in this pull request? We need to notify the await batch lock when the stream exits early e.g., when an exception has been thrown. ## How was this patch tested? Current tests that throw exceptions at runtime will finish faster as a result of this update. zsxwing Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tyson Condie Closes #17231 from tcondie/kafka-writer. --- .../spark/sql/execution/streaming/StreamExecution.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 70912d13ae45..529263805c0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -361,6 +361,13 @@ class StreamExecution( } } } finally { + awaitBatchLock.lock() + try { + // Wake up any threads that are waiting for the stream to progress. 
+ awaitBatchLockCondition.signalAll() + } finally { + awaitBatchLock.unlock() + } terminationLatch.countDown() } } From fcb68e0f5d49234ac4527109887ff08cd4e1c29f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 10 Mar 2017 18:04:37 +0100 Subject: [PATCH 004/512] [SPARK-19786][SQL] Facilitate loop optimizations in a JIT compiler regarding range() ## What changes were proposed in this pull request? This PR improves performance of operations with `range()` by changing Java code generated by Catalyst. This PR is inspired by the [blog article](https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html). This PR changes generated code in the following two points. 1. Replace a while-loop with long instance variables a for-loop with int local varibles 2. Suppress generation of `shouldStop()` method if this method is unnecessary (e.g. `append()` is not generated). These points facilitates compiler optimizations in a JIT compiler by feeding the simplified Java code into the JIT compiler. The performance is improved by 7.6x. Benchmark program: ```java val N = 1 << 29 val iters = 2 val benchmark = new Benchmark("range.count", N * iters) benchmark.addCase(s"with this PR") { i => var n = 0 var len = 0 while (n < iters) { len += sparkSession.range(N).selectExpr("count(id)").collect.length n += 1 } } benchmark.run ``` Performance result without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz range.count: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o this PR 1349 / 1356 796.2 1.3 1.0X ``` Performance result with this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz range.count: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ with this PR 177 / 271 6065.3 0.2 1.0X ``` Here is a comparison between generated code w/o and with this PR. Only the method ```agg_doAggregateWithoutKey``` is changed. 
Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows; /* 013 */ private boolean range_initRange; /* 014 */ private long range_number; /* 015 */ private TaskContext range_taskContext; /* 016 */ private InputMetrics range_inputMetrics; /* 017 */ private long range_batchEnd; /* 018 */ private long range_numElementsTodo; /* 019 */ private scala.collection.Iterator range_input; /* 020 */ private UnsafeRow range_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter; /* 023 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 024 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 025 */ private UnsafeRow agg_result; /* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 028 */ /* 029 */ public GeneratedIterator(Object[] references) { /* 030 */ this.references = references; /* 031 */ } /* 032 */ /* 033 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 034 */ partitionIndex = index; /* 035 */ this.inputs = inputs; /* 036 */ agg_initAgg = false; /* 037 */ /* 038 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 039 */ this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 040 */ range_initRange = false; /* 041 */ range_number = 0L; /* 042 */ range_taskContext = TaskContext.get(); /* 043 */ range_inputMetrics = range_taskContext.taskMetrics().inputMetrics(); /* 044 */ range_batchEnd = 0; /* 045 */ range_numElementsTodo = 0L; /* 046 */ range_input = inputs[0]; /* 047 */ range_result = new UnsafeRow(1); /* 048 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0); /* 049 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1); /* 050 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 051 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 052 */ agg_result = new UnsafeRow(1); /* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 055 */ /* 056 */ } /* 057 */ /* 058 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 059 */ // initialize aggregation buffer /* 060 */ agg_bufIsNull = false; /* 061 */ agg_bufValue = 0L; /* 062 */ /* 063 */ // initialize Range /* 064 */ if (!range_initRange) { /* 065 */ range_initRange = true; /* 066 */ initRange(partitionIndex); /* 067 */ } /* 068 */ /* 069 */ while (true) { /* 070 */ while (range_number != range_batchEnd) { /* 071 */ long range_value = range_number; /* 072 */ 
range_number += 1L; /* 073 */ /* 074 */ // do aggregate /* 075 */ // common sub-expressions /* 076 */ /* 077 */ // evaluate aggregate function /* 078 */ boolean agg_isNull1 = false; /* 079 */ /* 080 */ long agg_value1 = -1L; /* 081 */ agg_value1 = agg_bufValue + 1L; /* 082 */ // update aggregation buffer /* 083 */ agg_bufIsNull = false; /* 084 */ agg_bufValue = agg_value1; /* 085 */ /* 086 */ if (shouldStop()) return; /* 087 */ } /* 088 */ /* 089 */ if (range_taskContext.isInterrupted()) { /* 090 */ throw new TaskKilledException(); /* 091 */ } /* 092 */ /* 093 */ long range_nextBatchTodo; /* 094 */ if (range_numElementsTodo > 1000L) { /* 095 */ range_nextBatchTodo = 1000L; /* 096 */ range_numElementsTodo -= 1000L; /* 097 */ } else { /* 098 */ range_nextBatchTodo = range_numElementsTodo; /* 099 */ range_numElementsTodo = 0; /* 100 */ if (range_nextBatchTodo == 0) break; /* 101 */ } /* 102 */ range_numOutputRows.add(range_nextBatchTodo); /* 103 */ range_inputMetrics.incRecordsRead(range_nextBatchTodo); /* 104 */ /* 105 */ range_batchEnd += range_nextBatchTodo * 1L; /* 106 */ } /* 107 */ /* 108 */ } /* 109 */ /* 110 */ private void initRange(int idx) { /* 111 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 112 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L); /* 113 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L); /* 114 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 115 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 117 */ /* 118 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 119 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 120 */ range_number = Long.MAX_VALUE; /* 121 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 122 */ range_number = Long.MIN_VALUE; /* 123 */ } else { /* 124 */ range_number = st.longValue(); /* 125 */ } /* 126 */ range_batchEnd = range_number; /* 127 */ /* 128 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 129 */ .multiply(step).add(start); /* 130 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 131 */ partitionEnd = Long.MAX_VALUE; /* 132 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 133 */ partitionEnd = Long.MIN_VALUE; /* 134 */ } else { /* 135 */ partitionEnd = end.longValue(); /* 136 */ } /* 137 */ /* 138 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract( /* 139 */ java.math.BigInteger.valueOf(range_number)); /* 140 */ range_numElementsTodo = startToEnd.divide(step).longValue(); /* 141 */ if (range_numElementsTodo < 0) { /* 142 */ range_numElementsTodo = 0; /* 143 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) { /* 144 */ range_numElementsTodo++; /* 145 */ } /* 146 */ } /* 147 */ /* 148 */ protected void processNext() throws java.io.IOException { /* 149 */ while (!agg_initAgg) { /* 150 */ agg_initAgg = true; /* 151 */ long agg_beforeAgg = System.nanoTime(); /* 152 */ agg_doAggregateWithoutKey(); /* 153 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 154 */ /* 155 */ // output the result /* 156 */ /* 157 */ agg_numOutputRows.add(1); /* 158 */ agg_rowWriter.zeroOutNullBytes(); /* 159 */ /* 160 */ if (agg_bufIsNull) { /* 161 */ agg_rowWriter.setNullAt(0); /* 162 */ } else { /* 163 */ agg_rowWriter.write(0, 
agg_bufValue); /* 164 */ } /* 165 */ append(agg_result); /* 166 */ } /* 167 */ } /* 168 */ } ``` Generated code with this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows; /* 013 */ private boolean range_initRange; /* 014 */ private long range_number; /* 015 */ private TaskContext range_taskContext; /* 016 */ private InputMetrics range_inputMetrics; /* 017 */ private long range_batchEnd; /* 018 */ private long range_numElementsTodo; /* 019 */ private scala.collection.Iterator range_input; /* 020 */ private UnsafeRow range_result; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter; /* 023 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 024 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 025 */ private UnsafeRow agg_result; /* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 028 */ /* 029 */ public GeneratedIterator(Object[] references) { /* 030 */ this.references = references; /* 031 */ } /* 032 */ /* 033 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 034 */ partitionIndex = index; /* 035 */ this.inputs = inputs; /* 036 */ agg_initAgg = false; /* 037 */ /* 038 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 039 */ this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 040 */ range_initRange = false; /* 041 */ range_number = 0L; /* 042 */ range_taskContext = TaskContext.get(); /* 043 */ range_inputMetrics = range_taskContext.taskMetrics().inputMetrics(); /* 044 */ range_batchEnd = 0; /* 045 */ range_numElementsTodo = 0L; /* 046 */ range_input = inputs[0]; /* 047 */ range_result = new UnsafeRow(1); /* 048 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0); /* 049 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1); /* 050 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 051 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 052 */ agg_result = new UnsafeRow(1); /* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 055 */ /* 056 */ } /* 057 */ /* 058 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 059 */ // initialize aggregation buffer /* 060 */ agg_bufIsNull = false; /* 061 */ agg_bufValue = 0L; /* 062 */ /* 063 */ // initialize Range /* 064 */ if (!range_initRange) { /* 065 */ range_initRange = true; /* 066 */ initRange(partitionIndex); /* 067 */ } /* 068 */ /* 069 */ while (true) { /* 070 */ 
long range_range = range_batchEnd - range_number; /* 071 */ if (range_range != 0L) { /* 072 */ int range_localEnd = (int)(range_range / 1L); /* 073 */ for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) { /* 074 */ long range_value = ((long)range_localIdx * 1L) + range_number; /* 075 */ /* 076 */ // do aggregate /* 077 */ // common sub-expressions /* 078 */ /* 079 */ // evaluate aggregate function /* 080 */ boolean agg_isNull1 = false; /* 081 */ /* 082 */ long agg_value1 = -1L; /* 083 */ agg_value1 = agg_bufValue + 1L; /* 084 */ // update aggregation buffer /* 085 */ agg_bufIsNull = false; /* 086 */ agg_bufValue = agg_value1; /* 087 */ /* 088 */ // shouldStop check is eliminated /* 089 */ } /* 090 */ range_number = range_batchEnd; /* 091 */ } /* 092 */ /* 093 */ if (range_taskContext.isInterrupted()) { /* 094 */ throw new TaskKilledException(); /* 095 */ } /* 096 */ /* 097 */ long range_nextBatchTodo; /* 098 */ if (range_numElementsTodo > 1000L) { /* 099 */ range_nextBatchTodo = 1000L; /* 100 */ range_numElementsTodo -= 1000L; /* 101 */ } else { /* 102 */ range_nextBatchTodo = range_numElementsTodo; /* 103 */ range_numElementsTodo = 0; /* 104 */ if (range_nextBatchTodo == 0) break; /* 105 */ } /* 106 */ range_numOutputRows.add(range_nextBatchTodo); /* 107 */ range_inputMetrics.incRecordsRead(range_nextBatchTodo); /* 108 */ /* 109 */ range_batchEnd += range_nextBatchTodo * 1L; /* 110 */ } /* 111 */ /* 112 */ } /* 113 */ /* 114 */ private void initRange(int idx) { /* 115 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 116 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L); /* 117 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L); /* 118 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 119 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 120 */ long partitionEnd; /* 121 */ /* 122 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 123 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 124 */ range_number = Long.MAX_VALUE; /* 125 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 126 */ range_number = Long.MIN_VALUE; /* 127 */ } else { /* 128 */ range_number = st.longValue(); /* 129 */ } /* 130 */ range_batchEnd = range_number; /* 131 */ /* 132 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 133 */ .multiply(step).add(start); /* 134 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 135 */ partitionEnd = Long.MAX_VALUE; /* 136 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 137 */ partitionEnd = Long.MIN_VALUE; /* 138 */ } else { /* 139 */ partitionEnd = end.longValue(); /* 140 */ } /* 141 */ /* 142 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract( /* 143 */ java.math.BigInteger.valueOf(range_number)); /* 144 */ range_numElementsTodo = startToEnd.divide(step).longValue(); /* 145 */ if (range_numElementsTodo < 0) { /* 146 */ range_numElementsTodo = 0; /* 147 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) { /* 148 */ range_numElementsTodo++; /* 149 */ } /* 150 */ } /* 151 */ /* 152 */ protected void processNext() throws java.io.IOException { /* 153 */ while (!agg_initAgg) { /* 154 */ agg_initAgg = true; /* 155 */ long agg_beforeAgg = System.nanoTime(); /* 156 
*/ agg_doAggregateWithoutKey(); /* 157 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 158 */ /* 159 */ // output the result /* 160 */ /* 161 */ agg_numOutputRows.add(1); /* 162 */ agg_rowWriter.zeroOutNullBytes(); /* 163 */ /* 164 */ if (agg_bufIsNull) { /* 165 */ agg_rowWriter.setNullAt(0); /* 166 */ } else { /* 167 */ agg_rowWriter.write(0, agg_bufValue); /* 168 */ } /* 169 */ append(agg_result); /* 170 */ } /* 171 */ } /* 172 */ } ``` A part of suppressing `shouldStop()` was originally developed by inouehrs ## How was this patch tested? Add new tests into `DataFrameRangeSuite` Author: Kazuaki Ishizaki Closes #17122 from kiszk/SPARK-19786. --- .../apache/spark/sql/execution/SortExec.scala | 2 ++ .../sql/execution/WholeStageCodegenExec.scala | 15 +++++++++++ .../aggregate/HashAggregateExec.scala | 2 ++ .../execution/basicPhysicalOperators.scala | 27 ++++++++++++++----- .../spark/sql/DataFrameRangeSuite.scala | 16 +++++++++++ 5 files changed, 55 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index cc576bbc4c80..f98ae82574d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -177,6 +177,8 @@ case class SortExec( """.stripMargin.trim } + protected override val shouldStopRequired = false + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { s""" |${row.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c58474eba05d..c31fd92447c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -206,6 +206,21 @@ trait CodegenSupport extends SparkPlan { def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException } + + /** + * For optimization to suppress shouldStop() in a loop of WholeStageCodegen. + * Returning true means we need to insert shouldStop() into the loop producing rows, if any. + */ + def isShouldStopRequired: Boolean = { + return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired) + } + + /** + * Set to false if this plan consumes all rows produced by children but doesn't output row + * to buffer by calling append(), so the children don't require shouldStop() + * in the loop of producing rows. 
+ */ + protected def shouldStopRequired: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4529ed067e56..68c8e6ce62cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -238,6 +238,8 @@ case class HashAggregateExec( """.stripMargin } + protected override val shouldStopRequired = false + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 87e90ed685cc..d876688a8aab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -387,8 +387,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") - // The default size of a batch. - val batchSize = 1000L + // The default size of a batch, which must be positive integer + val batchSize = 1000 ctx.addNewFunction("initRange", s""" @@ -434,6 +434,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val input = ctx.freshName("input") // Right now, Range is only used when there is one upstream. 
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + val localIdx = ctx.freshName("localIdx") + val localEnd = ctx.freshName("localEnd") + val range = ctx.freshName("range") + val shouldStop = if (isShouldStopRequired) { + s"if (shouldStop()) { $number = $value + ${step}L; return; }" + } else { + "// shouldStop check is eliminated" + } s""" | // initialize Range | if (!$initTerm) { @@ -442,11 +451,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } | | while (true) { - | while ($number != $batchEnd) { - | long $value = $number; - | $number += ${step}L; - | ${consume(ctx, Seq(ev))} - | if (shouldStop()) return; + | long $range = $batchEnd - $number; + | if ($range != 0L) { + | int $localEnd = (int)($range / ${step}L); + | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { + | long $value = ((long)$localIdx * ${step}L) + $number; + | ${consume(ctx, Seq(ev))} + | $shouldStop + | } + | $number = $batchEnd; | } | | if ($taskContext.isInterrupted()) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index acf393a9b0fa..5e323c02b253 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -89,6 +89,22 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000 val res13 = spark.range(-n, n, n / 9).select("id") assert(res13.count == 18) + + // range with non aggregation operation + val res14 = spark.range(0, 100, 2).toDF.filter("50 <= id") + val len14 = res14.collect.length + assert(len14 == 25) + + val res15 = spark.range(100, -100, -2).toDF.filter("id <= 0") + val len15 = res15.collect.length + assert(len15 == 50) + + val res16 = spark.range(-1500, 1500, 3).toDF.filter("0 <= id") + val len16 = res16.collect.length + assert(len16 == 500) + + val res17 = spark.range(10, 0, -1, 1).toDF.sortWithinPartitions("id") + assert(res17.collect === (1 to 10).map(i => Row(i)).toArray) } test("Range with randomized parameters") { From dd9049e0492cc70b629518fee9b3d1632374c612 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Fri, 10 Mar 2017 11:13:26 -0800 Subject: [PATCH 005/512] [SPARK-19620][SQL] Fix incorrect exchange coordinator id in the physical plan ## What changes were proposed in this pull request? When adaptive execution is enabled, an exchange coordinator is used in the Exchange operators. For Join, the same exchange coordinator is used for its two Exchanges. But the physical plan shows two different coordinator Ids which is confusing. This PR is to fix the incorrect exchange coordinator id in the physical plan. The coordinator object instead of the `Option[ExchangeCoordinator]` should be used to generate the identity hash code of the same coordinator. ## How was this patch tested? Before the patch, the physical plan shows two different exchange coordinator id for Join. 
``` == Physical Plan == *Project [key1#3L, value2#12L] +- *SortMergeJoin [key1#3L], [key2#11L], Inner :- *Sort [key1#3L ASC NULLS FIRST], false, 0 : +- Exchange(coordinator id: 1804587700) hashpartitioning(key1#3L, 10), coordinator[target post-shuffle partition size: 67108864] : +- *Project [(id#0L % 500) AS key1#3L] : +- *Filter isnotnull((id#0L % 500)) : +- *Range (0, 1000, step=1, splits=Some(10)) +- *Sort [key2#11L ASC NULLS FIRST], false, 0 +- Exchange(coordinator id: 793927319) hashpartitioning(key2#11L, 10), coordinator[target post-shuffle partition size: 67108864] +- *Project [(id#8L % 500) AS key2#11L, id#8L AS value2#12L] +- *Filter isnotnull((id#8L % 500)) +- *Range (0, 1000, step=1, splits=Some(10)) ``` After the patch, two exchange coordinator id are the same. Author: Carson Wang Closes #16952 from carsonwang/FixCoordinatorId. --- .../apache/spark/sql/execution/exchange/ShuffleExchange.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 125a4930c652..f06544ea8ed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -46,7 +46,7 @@ case class ShuffleExchange( override def nodeName: String = { val extraInfo = coordinator match { case Some(exchangeCoordinator) => - s"(coordinator id: ${System.identityHashCode(coordinator)})" + s"(coordinator id: ${System.identityHashCode(exchangeCoordinator)})" case None => "" } From 8f0490e22b4c7f1fdf381c70c5894d46b7f7e6fb Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 10 Mar 2017 13:33:58 -0800 Subject: [PATCH 006/512] [SPARK-17979][SPARK-14453] Remove deprecated SPARK_YARN_USER_ENV and SPARK_JAVA_OPTS This fix removes deprecated support for config `SPARK_YARN_USER_ENV`, as is mentioned in SPARK-17979. This fix also removes deprecated support for the following: ``` SPARK_YARN_USER_ENV SPARK_JAVA_OPTS SPARK_CLASSPATH SPARK_WORKER_INSTANCES ``` Related JIRA: [SPARK-14453]: https://issues.apache.org/jira/browse/SPARK-14453 [SPARK-12344]: https://issues.apache.org/jira/browse/SPARK-12344 [SPARK-15781]: https://issues.apache.org/jira/browse/SPARK-15781 Existing tests should pass. Author: Yong Tang Closes #17212 from yongtang/SPARK-17979. 
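As an illustrative aside (this sketch is not part of the patch; the object name, paths, and values are placeholders), the removed warnings pointed users at configuration properties instead of the deleted environment variables, for example:

```scala
// Illustrative sketch only: configuration-based replacements for the
// environment variables whose deprecated support is removed here.
// Paths and values are placeholders.
import org.apache.spark.SparkConf

object DeprecatedEnvReplacements {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      // instead of SPARK_CLASSPATH
      .set("spark.driver.extraClassPath", "/path/to/extra.jar")
      .set("spark.executor.extraClassPath", "/path/to/extra.jar")
      // instead of SPARK_JAVA_OPTS
      .set("spark.driver.extraJavaOptions", "-Dkey=value")
      .set("spark.executor.extraJavaOptions", "-Dkey=value")
      // instead of SPARK_WORKER_INSTANCES
      .set("spark.executor.instances", "4")
      // instead of SPARK_YARN_USER_ENV (YARN mode)
      .set("spark.yarn.appMasterEnv.MY_VAR", "value")
      .set("spark.executorEnv.MY_VAR", "value")
    conf.getAll.sorted.foreach { case (k, v) => println(s"$k=$v") }
  }
}
```

The same settings can of course be passed on the command line via `--conf`, `--driver-class-path`, `--driver-java-options`, or `--num-executors`, which is what the removed warning messages recommended.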
--- conf/spark-env.sh.template | 3 - .../scala/org/apache/spark/SparkConf.scala | 65 ------------------- .../spark/deploy/FaultToleranceTest.scala | 3 +- .../spark/launcher/WorkerCommandBuilder.scala | 1 - docs/rdd-programming-guide.md | 2 +- .../launcher/AbstractCommandBuilder.java | 1 - .../launcher/SparkClassCommandBuilder.java | 2 - .../launcher/SparkSubmitCommandBuilder.java | 1 - .../MesosCoarseGrainedSchedulerBackend.scala | 5 -- .../MesosFineGrainedSchedulerBackend.scala | 4 -- .../org/apache/spark/deploy/yarn/Client.scala | 39 +---------- .../spark/deploy/yarn/ExecutorRunnable.scala | 8 --- 12 files changed, 3 insertions(+), 131 deletions(-) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 5c1e876ef9af..94bd2c477a35 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -25,12 +25,10 @@ # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node # - SPARK_PUBLIC_DNS, to set the public dns name of the driver program -# - SPARK_CLASSPATH, default classpath entries to append # Options read by executors and drivers running inside the cluster # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node # - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program -# - SPARK_CLASSPATH, default classpath entries to append # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data # - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos @@ -48,7 +46,6 @@ # - SPARK_WORKER_CORES, to set the number of cores to use on this machine # - SPARK_WORKER_MEMORY, to set how much total memory workers have to give executors (e.g. 1000m, 2g) # - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT, to use non-default ports for the worker -# - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index fe912e639bcb..2a2ce0504dbb 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -518,71 +518,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } - // Check for legacy configs - sys.env.get("SPARK_JAVA_OPTS").foreach { value => - val warning = - s""" - |SPARK_JAVA_OPTS was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application - | - ./spark-submit with --driver-java-options to set -X options for a driver - | - spark.executor.extraJavaOptions to set -X options for executors - | - SPARK_DAEMON_JAVA_OPTS to set java options for standalone daemons (master or worker) - """.stripMargin - logWarning(warning) - - for (key <- Seq(executorOptsKey, driverOptsKey)) { - if (getOption(key).isDefined) { - throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.") - } else { - logWarning(s"Setting '$key' to '$value' as a work-around.") - set(key, value) - } - } - } - - sys.env.get("SPARK_CLASSPATH").foreach { value => - val warning = - s""" - |SPARK_CLASSPATH was detected (set to '$value'). 
- |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with --driver-class-path to augment the driver classpath - | - spark.executor.extraClassPath to augment the executor classpath - """.stripMargin - logWarning(warning) - - for (key <- Seq(executorClasspathKey, driverClassPathKey)) { - if (getOption(key).isDefined) { - throw new SparkException(s"Found both $key and SPARK_CLASSPATH. Use only the former.") - } else { - logWarning(s"Setting '$key' to '$value' as a work-around.") - set(key, value) - } - } - } - - if (!contains(sparkExecutorInstances)) { - sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => - val warning = - s""" - |SPARK_WORKER_INSTANCES was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with --num-executors to specify the number of executors - | - Or set SPARK_EXECUTOR_INSTANCES - | - spark.executor.instances to configure the number of instances in the spark config. - """.stripMargin - logWarning(warning) - - set("spark.executor.instances", value) - } - } - if (contains("spark.master") && get("spark.master").startsWith("yarn-")) { val warning = s"spark.master ${get("spark.master")} is deprecated in Spark 2.0+, please " + "instead use \"yarn\" with specified deploy mode." diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 320af5cf9755..c6307da61c7e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -43,8 +43,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} * Execute using * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest * - * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS - * *and* SPARK_JAVA_OPTS: + * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS: * - spark.deploy.recoveryMode=ZOOKEEPER * - spark.deploy.zookeeper.url=172.17.42.1:2181 * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port. diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 3fd812e9fcfe..4216b2627309 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -39,7 +39,6 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator)) cmd.add(s"-Xmx${memoryMb}M") command.javaOpts.foreach(cmd.add) - addOptionString(cmd, getenv("SPARK_JAVA_OPTS")) cmd } diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index cad9ff4e646e..e2bf2d7ca77c 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -457,7 +457,7 @@ If required, a Hadoop configuration can be passed in as a Python dict. 
Here is a Elasticsearch ESInputFormat: {% highlight python %} -$ SPARK_CLASSPATH=/path/to/elasticsearch-hadoop.jar ./bin/pyspark +$ ./bin/pyspark --jars /path/to/elasticsearch-hadoop.jar >>> conf = {"es.resource" : "index/type"} # assume Elasticsearch is running on localhost defaults >>> rdd = sc.newAPIHadoopRDD("org.elasticsearch.hadoop.mr.EsInputFormat", "org.apache.hadoop.io.NullWritable", diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index bc8d6037a367..6c0c3ebcaebf 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -135,7 +135,6 @@ List buildClassPath(String appClassPath) throws IOException { String sparkHome = getSparkHome(); Set cp = new LinkedHashSet<>(); - addToClassPath(cp, getenv("SPARK_CLASSPATH")); addToClassPath(cp, appClassPath); addToClassPath(cp, getConfDir()); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 81786841de22..7cf5b7379503 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -66,7 +66,6 @@ public List buildCommand(Map env) memKey = "SPARK_DAEMON_MEMORY"; break; case "org.apache.spark.executor.CoarseGrainedExecutorBackend": - javaOptsKeys.add("SPARK_JAVA_OPTS"); javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; break; @@ -84,7 +83,6 @@ public List buildCommand(Map env) memKey = "SPARK_DAEMON_MEMORY"; break; default: - javaOptsKeys.add("SPARK_JAVA_OPTS"); memKey = "SPARK_DRIVER_MEMORY"; break; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5e64fa7ed152..5f2da036ff9f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -240,7 +240,6 @@ private List buildSparkSubmitCommand(Map env) addOptionString(cmd, System.getenv("SPARK_DAEMON_JAVA_OPTS")); } addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS")); - addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS")); // We don't want the client to specify Xmx. 
These have to be set by their corresponding // memory flag --driver-memory or configuration entry spark.driver.memory diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 85c2e9c76f4b..c049a32eabf9 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -175,11 +175,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val environment = Environment.newBuilder() - val extraClassPath = conf.getOption("spark.executor.extraClassPath") - extraClassPath.foreach { cp => - environment.addVariables( - Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) - } val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") // Set the environment variable through a command prefix diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 215271302ec5..f198f8893b3d 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -106,10 +106,6 @@ private[spark] class MesosFineGrainedSchedulerBackend( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val environment = Environment.newBuilder() - sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => - environment.addVariables( - Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) - } val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e86bd5459311..ccb0f8fdbbc2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -748,14 +748,6 @@ private[spark] class Client( .map { case (k, v) => (k.substring(amEnvPrefix.length), v) } .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) } - // Keep this for backwards compatibility but users should move to the config - sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => - // Allow users to specify some environment variables. - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) - // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments. - env("SPARK_YARN_USER_ENV") = userEnvs - } - // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH // of the container processes too. Add all non-.py files directly to PYTHONPATH. 
// @@ -782,35 +774,7 @@ private[spark] class Client( sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) } - // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to - // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's - // SparkContext will not let that set spark* system properties, which is expected behavior for - // Yarn clients. So propagate it through the environment. - // - // Note that to warn the user about the deprecation in cluster mode, some code from - // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition - // described above). if (isClusterMode) { - sys.env.get("SPARK_JAVA_OPTS").foreach { value => - val warning = - s""" - |SPARK_JAVA_OPTS was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application - | - ./spark-submit with --driver-java-options to set -X options for a driver - | - spark.executor.extraJavaOptions to set -X options for executors - """.stripMargin - logWarning(warning) - for (proc <- Seq("driver", "executor")) { - val key = s"spark.$proc.extraJavaOptions" - if (sparkConf.contains(key)) { - throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.") - } - } - env("SPARK_JAVA_OPTS") = value - } // propagate PYSPARK_DRIVER_PYTHON and PYSPARK_PYTHON to driver in cluster mode Seq("PYSPARK_DRIVER_PYTHON", "PYSPARK_PYTHON").foreach { envname => if (!env.contains(envname)) { @@ -883,8 +847,7 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { - val driverOpts = sparkConf.get(DRIVER_JAVA_OPTIONS).orElse(sys.env.get("SPARK_JAVA_OPTS")) - driverOpts.foreach { opts => + sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ee85c043b8bc..3f4d236571ff 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -143,9 +143,6 @@ private[yarn] class ExecutorRunnable( sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } - sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) - } sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } @@ -229,11 +226,6 @@ private[yarn] class ExecutorRunnable( YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) } - // Keep this for backwards compatibility but users should move to the config - sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, From bc30351404d8bc610cbae65fdc12ca613e7735c6 Mon Sep 17 00:00:00 2001 From: Budde Date: Fri, 10 Mar 2017 15:18:37 -0800 Subject: [PATCH 007/512] 
[SPARK-19611][SQL] Preserve metastore field order when merging inferred schema ## What changes were proposed in this pull request? The ```HiveMetastoreCatalog.mergeWithMetastoreSchema()``` method added in #16944 may not preserve the same field order as the metastore schema in some cases, which can cause queries to fail. This change ensures that the metastore field order is preserved. ## How was this patch tested? A test for ensuring that metastore order is preserved was added to ```HiveSchemaInferenceSuite.``` The particular failure usecase from #16944 was tested manually as well. Author: Budde Closes #17249 from budde/PreserveMetastoreFieldOrder. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 +---- .../sql/hive/HiveSchemaInferenceSuite.scala | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 056af495590f..9f0d1ceb28fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -356,13 +356,10 @@ private[hive] object HiveMetastoreCatalog { .filterKeys(!inferredSchema.map(_.name.toLowerCase).contains(_)) .values .filter(_.nullable) - // Merge missing nullable fields to inferred schema and build a case-insensitive field map. val inferredFields = StructType(inferredSchema ++ missingNullables) .map(f => f.name.toLowerCase -> f).toMap - StructType(metastoreFields.map { case(name, field) => - field.copy(name = inferredFields(name).name) - }.toSeq) + StructType(metastoreSchema.map(f => f.copy(name = inferredFields(f.name).name))) } catch { case NonFatal(_) => val msg = s"""Detected conflicting schemas when merging the schema obtained from the Hive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index 78955803819c..e48ce2304d08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -293,6 +293,27 @@ class HiveSchemaInferenceSuite StructField("firstField", StringType, nullable = true), StructField("secondField", StringType, nullable = true)))) }.getMessage.contains("Detected conflicting schemas")) + + // Schema merge should maintain metastore order. 
+ assertResult( + StructType(Seq( + StructField("first_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("fourth_field", StringType, nullable = true), + StructField("fifth_field", StringType, nullable = true)))) { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("first_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("fourth_field", StringType, nullable = true), + StructField("fifth_field", StringType, nullable = true))), + StructType(Seq( + StructField("fifth_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true)))) + } } } From ffee4f1cefb0dfd8d9145ee3be82c6f7b799870b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 10 Mar 2017 15:19:32 -0800 Subject: [PATCH 008/512] [SPARK-19905][SQL] Bring back Dataset.inputFiles for Hive SerDe tables ## What changes were proposed in this pull request? `Dataset.inputFiles` works by matching `FileRelation`s in the query plan. In Spark 2.1, Hive SerDe tables are represented by `MetastoreRelation`, which inherits from `FileRelation`. However, in Spark 2.2, Hive SerDe tables are now represented by `CatalogRelation`, which doesn't inherit from `FileRelation` anymore, due to the unification of Hive SerDe tables and data source tables. This change breaks `Dataset.inputFiles` for Hive SerDe tables. This PR tries to fix this issue by explicitly matching `CatalogRelation`s that are Hive SerDe tables in `Dataset.inputFiles`. Note that we can't make `CatalogRelation` inherit from `FileRelation` since not all `CatalogRelation`s are file based (e.g., JDBC data source tables). ## How was this patch tested? New test case added in `HiveDDLSuite`. Author: Cheng Lian Closes #17247 from liancheng/spark-19905-hive-table-input-files. 
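A minimal usage sketch mirroring the new `HiveDDLSuite` test below (assumes a Hive-enabled `SparkSession` named `spark`):

```scala
// With this change, inputFiles also resolves the storage location of Hive SerDe tables.
spark.range(10).createOrReplaceTempView("spark_19905_view")
spark.sql("CREATE TABLE spark_19905 STORED AS RCFILE AS SELECT * FROM spark_19905_view")
assert(spark.table("spark_19905").inputFiles.nonEmpty)
```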
--- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 3 +++ .../spark/sql/hive/execution/HiveDDLSuite.scala | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0a4d3a93a07e..520663f62440 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -36,6 +36,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -2734,6 +2735,8 @@ class Dataset[T] private[sql]( fsBasedRelation.inputFiles case fr: FileRelation => fr.inputFiles + case r: CatalogRelation if DDLUtils.isHiveTable(r.tableMeta) => + r.tableMeta.storage.locationUri.map(_.toString).toArray }.flatten files.toSet.toArray } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 23aea2469778..79ad156c5561 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1865,4 +1865,15 @@ class HiveDDLSuite } } } + + test("SPARK-19905: Hive SerDe table input paths") { + withTable("spark_19905") { + withTempView("spark_19905_view") { + spark.range(10).createOrReplaceTempView("spark_19905_view") + sql("CREATE TABLE spark_19905 STORED AS RCFILE AS SELECT * FROM spark_19905_view") + assert(spark.table("spark_19905").inputFiles.nonEmpty) + assert(sql("SELECT input_file_name() FROM spark_19905").count() > 0) + } + } + } } From fb9beda54622e0c3190c6504fc468fa4e50eeb45 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 10 Mar 2017 16:14:22 -0800 Subject: [PATCH 009/512] [SPARK-19893][SQL] should not run DataFrame set oprations with map type ## What changes were proposed in this pull request? In spark SQL, map type can't be used in equality test/comparison, and `Intersect`/`Except`/`Distinct` do need equality test for all columns, we should not allow map type in `Intersect`/`Except`/`Distinct`. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17236 from cloud-fan/map. 
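A minimal sketch of the pattern that now fails analysis, taken from the regression test below (assumes a `SparkSession` named `spark` and `import spark.implicits._`):

```scala
import org.apache.spark.sql.functions.{lit, map}

val df = spark.range(1).select(map(lit("key"), $"id").as("m"))

// Each of these now throws AnalysisException:
// "Cannot have map type columns in DataFrame which calls set operations"
df.intersect(df)
df.except(df)
df.distinct()
```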
--- .../sql/catalyst/analysis/CheckAnalysis.scala | 25 ++++++++++++++++--- .../org/apache/spark/sql/DataFrameSuite.scala | 19 ++++++++++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 14 +++++------ 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7529f9028498..d32fbeb4e91e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -44,6 +44,18 @@ trait CheckAnalysis extends PredicateHelper { }).length > 1 } + protected def hasMapType(dt: DataType): Boolean = { + dt.existsRecursively(_.isInstanceOf[MapType]) + } + + protected def mapColumnInSetOperation(plan: LogicalPlan): Option[Attribute] = plan match { + case _: Intersect | _: Except | _: Distinct => + plan.output.find(a => hasMapType(a.dataType)) + case d: Deduplicate => + d.keys.find(a => hasMapType(a.dataType)) + case _ => None + } + private def checkLimitClause(limitExpr: Expression): Unit = { limitExpr match { case e if !e.foldable => failAnalysis( @@ -121,8 +133,7 @@ trait CheckAnalysis extends PredicateHelper { if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") - } - else if (conditions.nonEmpty) { + } else if (conditions.nonEmpty) { // Collect the columns from the subquery for further checking. var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) @@ -200,7 +211,7 @@ trait CheckAnalysis extends PredicateHelper { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case f @ Filter(condition, child) => + case Filter(condition, _) => splitConjunctivePredicates(condition).foreach { case _: PredicateSubquery | Not(_: PredicateSubquery) => case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => @@ -374,6 +385,14 @@ trait CheckAnalysis extends PredicateHelper { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) + // TODO: although map type is not orderable, technically map type should be able to be + // used in equality comparison, remove this type check once we support it. 
+ case o if mapColumnInSetOperation(o).isDefined => + val mapCol = mapColumnInSetOperation(o).get + failAnalysis("Cannot have map type columns in DataFrame which calls " + + s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " + + "is " + mapCol.dataType.simpleString) + case o if o.expressions.exists(!_.deterministic) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 19c2d5532d08..52bd4e19f895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1703,4 +1703,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)") checkAnswer(df, Row(BigDecimal(0.0)) :: Nil) } + + test("SPARK-19893: cannot run set operations with map type") { + val df = spark.range(1).select(map(lit("key"), $"id").as("m")) + val e = intercept[AnalysisException](df.intersect(df)) + assert(e.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e2 = intercept[AnalysisException](df.except(df)) + assert(e2.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e3 = intercept[AnalysisException](df.distinct()) + assert(e3.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + withTempView("v") { + df.createOrReplaceTempView("v") + val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) + assert(e4.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index f355a5200ce2..0250a53fe232 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -234,8 +234,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct) + DateType, TimestampType, ArrayType(IntegerType), struct) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => StructField(s"col$index", dataType, true) } @@ -244,10 +243,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Create an RDD for the schema val rdd = - sparkContext.parallelize((1 to 10000), 10).map { i => + sparkContext.parallelize(1 to 10000, 10).map { i => Row( - s"str${i}: test cache.", - s"binary${i}: test cache.".getBytes(StandardCharsets.UTF_8), + s"str$i: test cache.", + s"binary$i: test cache.".getBytes(StandardCharsets.UTF_8), null, i % 2 == 0, i.toByte, @@ -255,13 +254,12 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { i, Long.MaxValue - i.toLong, (i + 0.25).toFloat, - (i + 0.75), + i + 0.75, BigDecimal(Long.MaxValue.toString + ".12345"), new 
java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), new Date(i), new Timestamp(i * 1000000L), - (i to i + 10).toSeq, - (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + i to i + 10, Row((i - 0.25).toFloat, Seq(true, false, null))) } spark.createDataFrame(rdd, schema).createOrReplaceTempView("InMemoryCache_different_data_types") From f6fdf92d0dce2cb3340f3e2ff768e09ef69176cd Mon Sep 17 00:00:00 2001 From: windpiger Date: Fri, 10 Mar 2017 20:59:32 -0800 Subject: [PATCH 010/512] [SPARK-19723][SQL] create datasource table with an non-existent location should work ## What changes were proposed in this pull request? This JIRA is a follow up work after [SPARK-19583](https://issues.apache.org/jira/browse/SPARK-19583) As we discussed in that [PR](https://github.com/apache/spark/pull/16938) The following DDL for datasource table with an non-existent location should work: ``` CREATE TABLE ... (PARTITIONED BY ...) LOCATION path ``` Currently it will throw exception that path not exists for datasource table for datasource table ## How was this patch tested? unit test added Author: windpiger Closes #17055 from windpiger/CTDataSourcePathNotExists. --- .../command/createDataSourceTables.scala | 3 +- .../sql/execution/command/DDLSuite.scala | 106 ++++++++++------- .../sql/hive/execution/HiveDDLSuite.scala | 111 ++++++++---------- 3 files changed, 115 insertions(+), 105 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 3da66afceda9..2d890118ae0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -73,7 +73,8 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo className = table.provider.get, bucketSpec = table.bucketSpec, options = table.storage.properties ++ pathOption, - catalogTable = Some(tableWithDefaultOptions)).resolveRelation() + // As discussed in SPARK-19583, we don't check if the location is existed + catalogTable = Some(tableWithDefaultOptions)).resolveRelation(checkFilesExist = false) val partitionColumnNames = if (table.schema.nonEmpty) { table.partitionColumnNames diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 5f70a8ce8918..0666f446f3b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -230,7 +230,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } private def getDBPath(dbName: String): URI = { - val warehousePath = makeQualifiedPath(s"${spark.sessionState.conf.warehousePath}") + val warehousePath = makeQualifiedPath(spark.sessionState.conf.warehousePath) new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri } @@ -1899,7 +1899,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("insert data to a data source table which has a not existed location should succeed") { + test("insert data to a data source table which has a non-existing location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1939,7 +1939,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("insert into a 
data source table with no existed partition location should succeed") { + test("insert into a data source table with a non-existing partition location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1966,7 +1966,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("read data from a data source table which has a not existed location should succeed") { + test("read data from a data source table which has a non-existing location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1994,7 +1994,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("read data from a data source table with no existed partition location should succeed") { + test("read data from a data source table with non-existing partition location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -2016,48 +2016,72 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create datasource table with a non-existing location") { + withTable("t", "t1") { + withTempPath { dir => + spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t SELECT 1, 2") + assert(dir.exists()) + + checkAnswer(spark.table("t"), Row(1, 2)) + } + // partition table + withTempPath { dir => + spark.sql(s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2") + + val partDir = new File(dir, "a=1") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(2, 1)) + } + } + } + Seq(true, false).foreach { shouldDelete => - val tcName = if (shouldDelete) "non-existent" else "existed" + val tcName = if (shouldDelete) "non-existing" else "existed" test(s"CTAS for external data source table with a $tcName location") { withTable("t", "t1") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t - |USING parquet - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + withTempDir { dir => + if (shouldDelete) dir.delete() + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) } // partition table - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t1 - |USING parquet - |PARTITIONED BY(a, b) - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - - val partDir = new File(dir, "a=3") - assert(partDir.exists()) - - checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + withTempDir { dir => + if (shouldDelete) dir.delete() + spark.sql( + s""" + 
|CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 79ad156c5561..d29242bb47e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1663,43 +1663,73 @@ class HiveDDLSuite } } + test("create hive table with a non-existing location") { + withTable("t", "t1") { + withTempPath { dir => + spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t SELECT 1, 2") + assert(dir.exists()) + + checkAnswer(spark.table("t"), Row(1, 2)) + } + // partition table + withTempPath { dir => + spark.sql( + s""" + |CREATE TABLE t1(a int, b int) + |USING hive + |PARTITIONED BY(a) + |LOCATION '$dir' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2") + + val partDir = new File(dir, "a=1") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(2, 1)) + } + } + } + Seq(true, false).foreach { shouldDelete => - val tcName = if (shouldDelete) "non-existent" else "existed" - test(s"CTAS for external data source table with a $tcName location") { + val tcName = if (shouldDelete) "non-existing" else "existed" + + test(s"CTAS for external hive table with a $tcName location") { withTable("t", "t1") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTempDir { dir => + if (shouldDelete) dir.delete() spark.sql( s""" |CREATE TABLE t - |USING parquet + |USING hive |LOCATION '$dir' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) - } - // partition table - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } + } + // partition table + withTempDir { dir => + if (shouldDelete) dir.delete() spark.sql( s""" |CREATE TABLE t1 - |USING parquet + |USING hive |PARTITIONED BY(a, b) |LOCATION '$dir' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1707,51 +1737,6 @@ class HiveDDLSuite assert(partDir.exists()) checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) - } - } - } - - test(s"CTAS for external hive table with a $tcName location") { - withTable("t", "t1") { - withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t - |USING hive - |LOCATION 
'$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - - checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) - } - // partition table - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t1 - |USING hive - |PARTITIONED BY(a, b) - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) - - val partDir = new File(dir, "a=3") - assert(partDir.exists()) - - checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) } } } From e29a74d5b1fa3f9356b7af5dd7e3fce49bc8eb7d Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sun, 12 Mar 2017 08:29:37 +0000 Subject: [PATCH 011/512] [DOCS][SS] fix structured streaming python example ## What changes were proposed in this pull request? - SS python example: `TypeError: 'xxx' object is not callable` - some other doc issue. ## How was this patch tested? Jenkins. Author: uncleGen Closes #17257 from uncleGen/docs-ss-python. --- docs/structured-streaming-programming-guide.md | 18 +++++++++--------- .../execution/streaming/FileStreamSource.scala | 2 +- .../streaming/dstream/FileInputDStream.scala | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 995ac77a4fb3..798847237866 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -539,7 +539,7 @@ spark = SparkSession. ... # Read text from socket socketDF = spark \ - .readStream() \ + .readStream \ .format("socket") \ .option("host", "localhost") \ .option("port", 9999) \ @@ -552,7 +552,7 @@ socketDF.printSchema() # Read all the csv files written atomically in a directory userSchema = StructType().add("name", "string").add("age", "integer") csvDF = spark \ - .readStream() \ + .readStream \ .option("sep", ";") \ .schema(userSchema) \ .csv("/path/to/directory") # Equivalent to format("csv").load("/path/to/directory") @@ -971,7 +971,7 @@ Here is the compatibility matrix.

Update mode uses watermark to drop old aggregation state.

- Complete mode does drop not old aggregation state since by definition this mode + Complete mode does not drop old aggregation state since by definition this mode preserves all data in the Result Table. @@ -1201,13 +1201,13 @@ noAggDF = deviceDataDf.select("device").where("signal > 10") # Print new data to console noAggDF \ - .writeStream() \ + .writeStream \ .format("console") \ .start() # Write new data to Parquet files noAggDF \ - .writeStream() \ + .writeStream \ .format("parquet") \ .option("checkpointLocation", "path/to/checkpoint/dir") \ .option("path", "path/to/destination/dir") \ @@ -1218,14 +1218,14 @@ aggDF = df.groupBy("device").count() # Print updated aggregations to console aggDF \ - .writeStream() \ + .writeStream \ .outputMode("complete") \ .format("console") \ .start() # Have all the aggregates in an in memory table. The query name will be the table name aggDF \ - .writeStream() \ + .writeStream \ .queryName("aggregates") \ .outputMode("complete") \ .format("memory") \ @@ -1313,7 +1313,7 @@ query.lastProgress(); // the most recent progress update of this streaming qu
{% highlight python %} -query = df.writeStream().format("console").start() # get the query object +query = df.writeStream.format("console").start() # get the query object query.id() # get the unique identifier of the running query that persists across restarts from checkpoint data @@ -1658,7 +1658,7 @@ aggDF {% highlight python %} aggDF \ - .writeStream() \ + .writeStream \ .outputMode("complete") \ .option("checkpointLocation", "path/to/HDFS/dir") \ .format("memory") \ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 411a15ffceb6..a9e64c640042 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -97,7 +97,7 @@ class FileStreamSource( } seenFiles.purge() - logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = $maxFileAgeMs") + logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAgeMs = $maxFileAgeMs") /** * Returns the maximum offset that can be retrieved from the source. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index ed9305875cb7..905b1c52afa6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -230,7 +230,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * - It must pass the user-provided file filter. * - It must be newer than the ignore threshold. It is assumed that files older than the ignore * threshold have already been considered or are existing files before start - * (when newFileOnly = true). + * (when newFilesOnly = true). * - It must not be present in the recently selected files that this class remembers. * - It must not be newer than the time of the batch (i.e. `currentTime` for which this * file is being tested. This can occur if the driver was recovered, and the missing batches From 2f5187bde1544c452fe5116a2bd243653332a079 Mon Sep 17 00:00:00 2001 From: "xiaojian.fxj" Date: Sun, 12 Mar 2017 10:29:00 -0700 Subject: [PATCH 012/512] [SPARK-19831][CORE] Reuse the existing cleanupThreadExecutor to clean up the directories of finished applications to avoid the block Cleaning the application may cost much time at worker, then it will block that the worker send heartbeats master because the worker is extend ThreadSafeRpcEndpoint. If the heartbeat from a worker is blocked by the message ApplicationFinished, master will think the worker is dead. If the worker has a driver, the driver will be scheduled by master again. It had better reuse the existing cleanupThreadExecutor to clean up the directories of finished applications to avoid the block. Author: xiaojian.fxj Closes #17189 from hustfxj/worker-hearbeat. 
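A condensed sketch of the pattern the patch applies, with simplified names (`deleteRecursively` stands in for Spark's `Utils.deleteRecursively`, and logging is replaced by `println`): the blocking cleanup runs on the dedicated single-thread executor, so the worker's message loop, which also sends heartbeats, is never blocked.

```scala
import java.io.File
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}

val cleanupThreadExecutor =
  ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor())

// Simplified stand-in for Utils.deleteRecursively.
def deleteRecursively(f: File): Unit = {
  Option(f.listFiles()).foreach(_.foreach(deleteRecursively))
  f.delete()
}

def cleanupAppDirs(dirList: Seq[String]): Unit = {
  // Run the potentially slow directory deletion asynchronously on cleanupThreadExecutor.
  Future {
    dirList.foreach(dir => deleteRecursively(new File(dir)))
  }(cleanupThreadExecutor).onFailure {
    case e: Throwable => println(s"Clean up app dir $dirList failed: ${e.getMessage}")
  }(cleanupThreadExecutor)
}
```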
--- .../org/apache/spark/deploy/worker/Worker.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index e48817ebbafd..00b9d1af373d 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -62,8 +62,8 @@ private[deploy] class Worker( private val forwordMessageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") - // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` - // methods. + // A separated thread to clean up the workDir and the directories of finished applications. + // Used to provide the implicit parameter of `Future` methods. private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) @@ -578,10 +578,15 @@ private[deploy] class Worker( if (shouldCleanup) { finishedApps -= id appDirectories.remove(id).foreach { dirList => - logInfo(s"Cleaning up local directories for application $id") - dirList.foreach { dir => - Utils.deleteRecursively(new File(dir)) - } + concurrent.Future { + logInfo(s"Cleaning up local directories for application $id") + dirList.foreach { dir => + Utils.deleteRecursively(new File(dir)) + } + }(cleanupThreadExecutor).onFailure { + case e: Throwable => + logError(s"Clean up app dir $dirList failed: ${e.getMessage}", e) + }(cleanupThreadExecutor) } shuffleService.applicationRemoved(id) } From 9f8ce4825e378b6a856ce65cb9986a5a0f0b624e Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Sun, 12 Mar 2017 12:15:19 -0700 Subject: [PATCH 013/512] [SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT Wrapper return param "maxDepth" to R models ## What changes were proposed in this pull request? RandomForest R Wrapper and GBT R Wrapper return param `maxDepth` to R models. Below 4 R wrappers are changed: * `RandomForestClassificationWrapper` * `RandomForestRegressionWrapper` * `GBTClassificationWrapper` * `GBTRegressionWrapper` ## How was this patch tested? Test manually on my local machine. Author: Xin Ren Closes #17207 from keypointt/SPARK-19282. 
--- R/pkg/R/mllib_tree.R | 11 +++++++---- R/pkg/inst/tests/testthat/test_mllib_tree.R | 10 ++++++++++ .../apache/spark/ml/r/GBTClassificationWrapper.scala | 1 + .../org/apache/spark/ml/r/GBTRegressionWrapper.scala | 1 + .../ml/r/RandomForestClassificationWrapper.scala | 1 + .../spark/ml/r/RandomForestRegressionWrapper.scala | 1 + 6 files changed, 21 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 40a806c41bad..82279be6fbe7 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -52,12 +52,14 @@ summary.treeEnsemble <- function(model) { numFeatures <- callJMethod(jobj, "numFeatures") features <- callJMethod(jobj, "features") featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + maxDepth <- callJMethod(jobj, "maxDepth") numTrees <- callJMethod(jobj, "numTrees") treeWeights <- callJMethod(jobj, "treeWeights") list(formula = formula, numFeatures = numFeatures, features = features, featureImportances = featureImportances, + maxDepth = maxDepth, numTrees = numTrees, treeWeights = treeWeights, jobj = jobj) @@ -70,6 +72,7 @@ print.summary.treeEnsemble <- function(x) { cat("\nNumber of features: ", x$numFeatures) cat("\nFeatures: ", unlist(x$features)) cat("\nFeature importances: ", x$featureImportances) + cat("\nMax Depth: ", x$maxDepth) cat("\nNumber of trees: ", x$numTrees) cat("\nTree weights: ", unlist(x$treeWeights)) @@ -197,8 +200,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method #' @export @@ -403,8 +406,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). 
#' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index e6fda251ebea..e0802a9b02d1 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -39,6 +39,7 @@ test_that("spark.gbt", { tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_equal(stats$formula, "Employed ~ .") expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) @@ -53,6 +54,7 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, stats2$numFeatures) expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) expect_equal(stats$numTrees, stats2$numTrees) expect_equal(stats$treeWeights, stats2$treeWeights) @@ -66,6 +68,7 @@ test_that("spark.gbt", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) predictions <- collect(predict(model, data))$prediction @@ -93,6 +96,7 @@ test_that("spark.gbt", { expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) expect_equal(s$numFeatures, 5) expect_equal(s$numTrees, 20) + expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), @@ -116,6 +120,7 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numTrees, 1) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) @@ -129,6 +134,7 @@ test_that("spark.randomForest", { tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") write.ml(model, modelPath) @@ -141,6 +147,7 @@ test_that("spark.randomForest", { expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) @@ -153,6 +160,7 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) # Test string prediction values @@ -187,6 +195,8 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + # Test numeric prediction values predictions <- collect(predict(model, data))$prediction expect_equal(length(grep("1.0", predictions)), 50) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index aacb41ee2659..c07eadb30a4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -44,6 +44,7 @@ private[r] class GBTClassifierWrapper private ( lazy val featureImportances: Vector 
= gbtcModel.featureImportances lazy val numTrees: Int = gbtcModel.getNumTrees lazy val treeWeights: Array[Double] = gbtcModel.treeWeights + lazy val maxDepth: Int = gbtcModel.getMaxDepth def summary: String = gbtcModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala index 585077588eb9..b568d7859221 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala @@ -42,6 +42,7 @@ private[r] class GBTRegressorWrapper private ( lazy val featureImportances: Vector = gbtrModel.featureImportances lazy val numTrees: Int = gbtrModel.getNumTrees lazy val treeWeights: Array[Double] = gbtrModel.treeWeights + lazy val maxDepth: Int = gbtrModel.getMaxDepth def summary: String = gbtrModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 366f375b5858..8a83d4e980f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -44,6 +44,7 @@ private[r] class RandomForestClassifierWrapper private ( lazy val featureImportances: Vector = rfcModel.featureImportances lazy val numTrees: Int = rfcModel.getNumTrees lazy val treeWeights: Array[Double] = rfcModel.treeWeights + lazy val maxDepth: Int = rfcModel.getMaxDepth def summary: String = rfcModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala index 4b9a3a731da9..038bd79c7022 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -42,6 +42,7 @@ private[r] class RandomForestRegressorWrapper private ( lazy val featureImportances: Vector = rfrModel.featureImportances lazy val numTrees: Int = rfrModel.getNumTrees lazy val treeWeights: Array[Double] = rfrModel.treeWeights + lazy val maxDepth: Int = rfrModel.getMaxDepth def summary: String = rfrModel.toDebugString From 0a4d06a7c3db9fec2b6f050a631e8b59b0e9376e Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sun, 12 Mar 2017 17:46:31 -0700 Subject: [PATCH 014/512] [SPARK-19853][SS] uppercase kafka topics fail when startingOffsets are SpecificOffsets When using the KafkaSource with Structured Streaming, consumer assignments are not what the user expects if startingOffsets is set to an explicit set of topics/partitions in JSON where the topic(s) happen to have uppercase characters. When StartingOffsets is constructed, the original string value from options is transformed toLowerCase to make matching on "earliest" and "latest" case insensitive. However, the toLowerCase JSON is passed to SpecificOffsets for the terminal condition, so topic names may not be what the user intended by the time assignments are made with the underlying KafkaConsumer. 
KafkaSourceProvider.scala: ``` val startingOffsets = caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { case Some("latest") => LatestOffsets case Some("earliest") => EarliestOffsets case Some(json) => SpecificOffsets(JsonUtils.partitionOffsets(json)) case None => LatestOffsets } ``` Thank cbowden for reporting. Jenkins Author: uncleGen Closes #17209 from uncleGen/SPARK-19853. --- .../sql/kafka010/KafkaSourceProvider.scala | 69 ++++++++++--------- .../spark/sql/kafka010/KafkaSourceSuite.scala | 19 +++++ 2 files changed, 54 insertions(+), 34 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index febe3c217122..58b52692b57c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -82,13 +82,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister .map { k => k.drop(6).toString -> parameters(k) } .toMap - val startingStreamOffsets = - caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("latest") => LatestOffsetRangeLimit - case Some("earliest") => EarliestOffsetRangeLimit - case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) - case None => LatestOffsetRangeLimit - } + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) val kafkaOffsetReader = new KafkaOffsetReader( strategy(caseInsensitiveParams), @@ -128,19 +123,13 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister .map { k => k.drop(6).toString -> parameters(k) } .toMap - val startingRelationOffsets = - caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("earliest") => EarliestOffsetRangeLimit - case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) - case None => EarliestOffsetRangeLimit - } + val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) + assert(startingRelationOffsets != LatestOffsetRangeLimit) - val endingRelationOffsets = - caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("latest") => LatestOffsetRangeLimit - case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) - case None => LatestOffsetRangeLimit - } + val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + assert(endingRelationOffsets != EarliestOffsetRangeLimit) val kafkaOffsetReader = new KafkaOffsetReader( strategy(caseInsensitiveParams), @@ -388,34 +377,34 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister private def validateBatchOptions(caseInsensitiveParams: Map[String, String]) = { // Batch specific options - caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("earliest") => // good to go - case Some("latest") => + KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match { + case EarliestOffsetRangeLimit => // good to go + case 
LatestOffsetRangeLimit => throw new IllegalArgumentException("starting offset can't be latest " + "for batch queries on Kafka") - case Some(json) => (SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))) - .partitionOffsets.foreach { + case SpecificOffsetRangeLimit(partitionOffsets) => + partitionOffsets.foreach { case (tp, off) if off == KafkaOffsetRangeLimit.LATEST => throw new IllegalArgumentException(s"startingOffsets for $tp can't " + "be latest for batch queries on Kafka") case _ => // ignore } - case _ => // default to earliest } - caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("earliest") => + KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match { + case EarliestOffsetRangeLimit => throw new IllegalArgumentException("ending offset can't be earliest " + "for batch queries on Kafka") - case Some("latest") => // good to go - case Some(json) => (SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))) - .partitionOffsets.foreach { + case LatestOffsetRangeLimit => // good to go + case SpecificOffsetRangeLimit(partitionOffsets) => + partitionOffsets.foreach { case (tp, off) if off == KafkaOffsetRangeLimit.EARLIEST => throw new IllegalArgumentException(s"ending offset for $tp can't be " + "earliest for batch queries on Kafka") case _ => // ignore } - case _ => // default to latest } validateGeneralOptions(caseInsensitiveParams) @@ -432,7 +421,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister def set(key: String, value: Object): this.type = { map.put(key, value) - logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.get(key).getOrElse("")}") + logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") this } @@ -450,10 +439,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister private[kafka010] object KafkaSourceProvider { private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign") - private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" - private val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" + private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" + private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" val TOPIC_OPTION_KEY = "topic" private val deserClassName = classOf[ByteArrayDeserializer].getName + + def getKafkaOffsetRangeLimit( + params: Map[String, String], + offsetOptionKey: String, + defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = { + params.get(offsetOptionKey).map(_.trim) match { + case Some(offset) if offset.toLowerCase == "latest" => LatestOffsetRangeLimit + case Some(offset) if offset.toLowerCase == "earliest" => EarliestOffsetRangeLimit + case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) + case None => defaultOffsets + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 534fb77c9ce1..bf6aad671a18 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkContext import org.apache.spark.sql.ForeachWriter import 
org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -606,6 +607,24 @@ class KafkaSourceSuite extends KafkaSourceTest { assert(query.exception.isEmpty) } + test("get offsets from case insensitive parameters") { + for ((optionKey, optionValue, answer) <- Seq( + (STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit), + (ENDING_OFFSETS_OPTION_KEY, "laTest", LatestOffsetRangeLimit), + (STARTING_OFFSETS_OPTION_KEY, """{"topic-A":{"0":23}}""", + SpecificOffsetRangeLimit(Map(new TopicPartition("topic-A", 0) -> 23))))) { + val offset = getKafkaOffsetRangeLimit(Map(optionKey -> optionValue), optionKey, answer) + assert(offset === answer) + } + + for ((optionKey, answer) <- Seq( + (STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit), + (ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit))) { + val offset = getKafkaOffsetRangeLimit(Map.empty, optionKey, answer) + assert(offset === answer) + } + } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" private def assignString(topic: String, partitions: Iterable[Int]): String = { From 9456688547522a62f1e7520e9b3564550c57aa5d Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Sun, 12 Mar 2017 20:08:44 -0700 Subject: [PATCH 015/512] [SPARK-17495][SQL] Support date, timestamp and interval types in Hive hash ## What changes were proposed in this pull request? - Timestamp hashing is done as per [TimestampWritable.hashCode()](https://github.com/apache/hive/blob/ff67cdda1c538dc65087878eeba3e165cf3230f4/serde/src/java/org/apache/hadoop/hive/serde2/io/TimestampWritable.java#L406) in Hive - Interval hashing is done as per [HiveIntervalDayTime.hashCode()](https://github.com/apache/hive/blob/ff67cdda1c538dc65087878eeba3e165cf3230f4/storage-api/src/java/org/apache/hadoop/hive/common/type/HiveIntervalDayTime.java#L178). Note that there are inherent differences in how Hive and Spark store intervals under the hood which limits the ability to be in completely sync with hive's hashing function. I have explained this in the method doc. - Date type was already supported. This PR adds test for that. ## How was this patch tested? Added unit tests Author: Tejas Patil Closes #17062 from tejasapatil/SPARK-17495_time_related_types. 
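For concreteness, a standalone sketch of the timestamp folding this patch implements (it mirrors the `hashTimestamp` helper added in the diff below; the sample input corresponds to one of the new test cases):

```scala
// Spark stores TimestampType values as microseconds since the epoch.
// Hive's TimestampWritable.hashCode() packs seconds and nanoseconds into one long
// and folds the two 32-bit halves together.
def hashTimestampSketch(micros: Long): Int = {
  val seconds = micros / 1000000
  val nanos = (micros % 1000000) * 1000
  var result = seconds
  result <<= 30          // the nanosecond part fits in 30 bits
  result |= nanos
  ((result >>> 32) ^ result).toInt
}

hashTimestampSketch(1487933789000000L) // 2017-02-24 10:56:29 UTC -> 1445725271
```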
--- .../spark/sql/catalyst/expressions/hash.scala | 71 +++++- .../expressions/HashExpressionsSuite.scala | 208 +++++++++++++++++- 2 files changed, 268 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 03101b4bfc5f..2a5963d37f5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -335,6 +335,8 @@ abstract class HashExpression[E] extends Expression { } } + protected def genHashTimestamp(t: String, result: String): String = genHashLong(t, result) + protected def genHashCalendarInterval(input: String, result: String): String = { val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)" s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);" @@ -400,7 +402,8 @@ abstract class HashExpression[E] extends Expression { case NullType => "" case BooleanType => genHashBoolean(input, result) case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result) - case LongType | TimestampType => genHashLong(input, result) + case LongType => genHashLong(input, result) + case TimestampType => genHashTimestamp(input, result) case FloatType => genHashFloat(input, result) case DoubleType => genHashDouble(input, result) case d: DecimalType => genHashDecimal(ctx, d, input, result) @@ -433,6 +436,10 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + /** + * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity + * of input `value`. + */ def hash(value: Any, dataType: DataType, seed: Long): Long = { value match { case null => seed @@ -580,8 +587,6 @@ object XxHash64Function extends InterpretedHashFunction { * * We should use this hash function for both shuffle and bucket of Hive tables, so that * we can guarantee shuffle and bucketing have same data distribution - * - * TODO: Support date related types */ @ExpressionDescription( usage = "_FUNC_(expr1, expr2, ...) 
- Returns a hash value of the arguments.") @@ -648,11 +653,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashCalendarInterval(input: String, result: String): String = { s""" - $result = (31 * $hasherClassName.hashInt($input.months)) + - $hasherClassName.hashLong($input.microseconds);" + $result = (int) + ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashCalendarInterval($input); """ } + override protected def genHashTimestamp(input: String, result: String): String = + s""" + $result = (int) ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashTimestamp($input); + """ + override protected def genHashString(input: String, result: String): String = { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" @@ -781,6 +791,49 @@ object HiveHashFunction extends InterpretedHashFunction { result } + /** + * Mimics TimestampWritable.hashCode() in Hive + */ + def hashTimestamp(timestamp: Long): Long = { + val timestampInSeconds = timestamp / 1000000 + val nanoSecondsPortion = (timestamp % 1000000) * 1000 + + var result = timestampInSeconds + result <<= 30 // the nanosecond part fits in 30 bits + result |= nanoSecondsPortion + ((result >>> 32) ^ result).toInt + } + + /** + * Hive allows input intervals to be defined using units below but the intervals + * have to be from the same category: + * - year, month (stored as HiveIntervalYearMonth) + * - day, hour, minute, second, nanosecond (stored as HiveIntervalDayTime) + * + * eg. (INTERVAL '30' YEAR + INTERVAL '-23' DAY) fails in Hive + * + * This method mimics HiveIntervalDayTime.hashCode() in Hive. + * + * Two differences wrt Hive due to how intervals are stored in Spark vs Hive: + * + * - If the `INTERVAL` is backed as HiveIntervalYearMonth in Hive, then this method will not + * produce Hive compatible result. The reason being Spark's representation of calendar does not + * have such categories based on the interval and is unified. + * + * - Spark's [[CalendarInterval]] has precision upto microseconds but Hive's + * HiveIntervalDayTime can store data with precision upto nanoseconds. So, any input intervals + * with nanosecond values will lead to wrong output hashes (ie. 
non adherent with Hive output) + */ + def hashCalendarInterval(calendarInterval: CalendarInterval): Long = { + val totalSeconds = calendarInterval.microseconds / CalendarInterval.MICROS_PER_SECOND.toInt + val result: Int = (17 * 37) + (totalSeconds ^ totalSeconds >> 32).toInt + + val nanoSeconds = + (calendarInterval.microseconds - + (totalSeconds * CalendarInterval.MICROS_PER_SECOND.toInt)).toInt * 1000 + (result * 37) + nanoSeconds + } + override def hash(value: Any, dataType: DataType, seed: Long): Long = { value match { case null => 0 @@ -834,10 +887,10 @@ object HiveHashFunction extends InterpretedHashFunction { } result - case d: Decimal => - normalizeDecimal(d.toJavaBigDecimal).hashCode() - - case _ => super.hash(value, dataType, seed) + case d: Decimal => normalizeDecimal(d.toJavaBigDecimal).hashCode() + case timestamp: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(timestamp) + case calendarInterval: CalendarInterval => hashCalendarInterval(calendarInterval) + case _ => super.hash(value, dataType, 0) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 0c77dc2709da..59fc8eaf73d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -18,18 +18,20 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets +import java.util.TimeZone import scala.collection.mutable.ArrayBuffer import org.apache.commons.codec.digest.DigestUtils +import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types.{ArrayType, StructType, _} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val random = new scala.util.Random @@ -168,6 +170,208 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on nonascii } + test("hive-hash for date type") { + def checkHiveHashForDateType(dateString: String, expected: Long): Unit = { + checkHiveHash( + DateTimeUtils.stringToDate(UTF8String.fromString(dateString)).get, + DateType, + expected) + } + + // basic case + checkHiveHashForDateType("2017-01-01", 17167) + + // boundary cases + checkHiveHashForDateType("0000-01-01", -719530) + checkHiveHashForDateType("9999-12-31", 2932896) + + // epoch + checkHiveHashForDateType("1970-01-01", 0) + + // before epoch + checkHiveHashForDateType("1800-01-01", -62091) + + // Invalid input: bad date string. Hive returns 0 for such cases + intercept[NoSuchElementException](checkHiveHashForDateType("0-0-0", 0)) + intercept[NoSuchElementException](checkHiveHashForDateType("-1212-01-01", 0)) + intercept[NoSuchElementException](checkHiveHashForDateType("2016-99-99", 0)) + + // Invalid input: Empty string. 
Hive returns 0 for this case + intercept[NoSuchElementException](checkHiveHashForDateType("", 0)) + + // Invalid input: February 30th for a leap year. Hive supports this but Spark doesn't + intercept[NoSuchElementException](checkHiveHashForDateType("2016-02-30", 16861)) + } + + test("hive-hash for timestamp type") { + def checkHiveHashForTimestampType( + timestamp: String, + expected: Long, + timeZone: TimeZone = TimeZone.getTimeZone("UTC")): Unit = { + checkHiveHash( + DateTimeUtils.stringToTimestamp(UTF8String.fromString(timestamp), timeZone).get, + TimestampType, + expected) + } + + // basic case + checkHiveHashForTimestampType("2017-02-24 10:56:29", 1445725271) + + // with higher precision + checkHiveHashForTimestampType("2017-02-24 10:56:29.111111", 1353936655) + + // with different timezone + checkHiveHashForTimestampType("2017-02-24 10:56:29", 1445732471, + TimeZone.getTimeZone("US/Pacific")) + + // boundary cases + checkHiveHashForTimestampType("0001-01-01 00:00:00", 1645926784) + checkHiveHashForTimestampType("9999-01-01 00:00:00", -1081818240) + + // epoch + checkHiveHashForTimestampType("1970-01-01 00:00:00", 0) + + // before epoch + checkHiveHashForTimestampType("1800-01-01 03:12:45", -267420885) + + // Invalid input: bad timestamp string. Hive returns 0 for such cases + intercept[NoSuchElementException](checkHiveHashForTimestampType("0-0-0 0:0:0", 0)) + intercept[NoSuchElementException](checkHiveHashForTimestampType("-99-99-99 99:99:45", 0)) + intercept[NoSuchElementException](checkHiveHashForTimestampType("555555-55555-5555", 0)) + + // Invalid input: Empty string. Hive returns 0 for this case + intercept[NoSuchElementException](checkHiveHashForTimestampType("", 0)) + + // Invalid input: February 30th is a leap year. Hive supports this but Spark doesn't + intercept[NoSuchElementException](checkHiveHashForTimestampType("2016-02-30 00:00:00", 0)) + + // Invalid input: Hive accepts upto 9 decimal place precision but Spark uses upto 6 + intercept[TestFailedException](checkHiveHashForTimestampType("2017-02-24 10:56:29.11111111", 0)) + } + + test("hive-hash for CalendarInterval type") { + def checkHiveHashForIntervalType(interval: String, expected: Long): Unit = { + checkHiveHash(CalendarInterval.fromString(interval), CalendarIntervalType, expected) + } + + // ----- MICROSEC ----- + + // basic case + checkHiveHashForIntervalType("interval 1 microsecond", 24273) + + // negative + checkHiveHashForIntervalType("interval -1 microsecond", 22273) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 microsecond", 23273) + checkHiveHashForIntervalType("interval 999 microsecond", 1022273) + checkHiveHashForIntervalType("interval -999 microsecond", -975727) + + // ----- MILLISEC ----- + + // basic case + checkHiveHashForIntervalType("interval 1 millisecond", 1023273) + + // negative + checkHiveHashForIntervalType("interval -1 millisecond", -976727) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 millisecond", 23273) + checkHiveHashForIntervalType("interval 999 millisecond", 999023273) + checkHiveHashForIntervalType("interval -999 millisecond", -998976727) + + // ----- SECOND ----- + + // basic case + checkHiveHashForIntervalType("interval 1 second", 23310) + + // negative + checkHiveHashForIntervalType("interval -1 second", 23273) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 second", 23273) + checkHiveHashForIntervalType("interval 2147483647 second", -2147460412) + checkHiveHashForIntervalType("interval -2147483648 
second", -2147460412) + + // Out of range for both Hive and Spark + // Hive throws an exception. Spark overflows and returns wrong output + // checkHiveHashForIntervalType("interval 9999999999 second", 0) + + // ----- MINUTE ----- + + // basic cases + checkHiveHashForIntervalType("interval 1 minute", 25493) + + // negative + checkHiveHashForIntervalType("interval -1 minute", 25456) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 minute", 23273) + checkHiveHashForIntervalType("interval 2147483647 minute", 21830) + checkHiveHashForIntervalType("interval -2147483648 minute", 22163) + + // Out of range for both Hive and Spark + // Hive throws an exception. Spark overflows and returns wrong output + // checkHiveHashForIntervalType("interval 9999999999 minute", 0) + + // ----- HOUR ----- + + // basic case + checkHiveHashForIntervalType("interval 1 hour", 156473) + + // negative + checkHiveHashForIntervalType("interval -1 hour", 156436) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 hour", 23273) + checkHiveHashForIntervalType("interval 2147483647 hour", -62308) + checkHiveHashForIntervalType("interval -2147483648 hour", -43327) + + // Out of range for both Hive and Spark + // Hive throws an exception. Spark overflows and returns wrong output + // checkHiveHashForIntervalType("interval 9999999999 hour", 0) + + // ----- DAY ----- + + // basic cases + checkHiveHashForIntervalType("interval 1 day", 3220073) + + // negative + checkHiveHashForIntervalType("interval -1 day", 3220036) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 day", 23273) + checkHiveHashForIntervalType("interval 106751991 day", -451506760) + checkHiveHashForIntervalType("interval -106751991 day", -451514123) + + // Hive supports `day` for a longer range but Spark's range is smaller + // The check for range is done at the parser level so this does not fail in Spark + // checkHiveHashForIntervalType("interval -2147483648 day", -1575127) + // checkHiveHashForIntervalType("interval 2147483647 day", -4767228) + + // Out of range for both Hive and Spark + // Hive throws an exception. Spark overflows and returns wrong output + // checkHiveHashForIntervalType("interval 9999999999 day", 0) + + // ----- MIX ----- + + checkHiveHashForIntervalType("interval 0 day 0 hour", 23273) + checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute", 23273) + checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute 0 second", 23273) + checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute 0 second 0 millisecond", 23273) + checkHiveHashForIntervalType( + "interval 0 day 0 hour 0 minute 0 second 0 millisecond 0 microsecond", 23273) + + checkHiveHashForIntervalType("interval 6 day 15 hour", 21202073) + checkHiveHashForIntervalType("interval 5 day 4 hour 8 minute", 16557833) + checkHiveHashForIntervalType("interval -23 day 56 hour -1111113 minute 9898989 second", + -2128468593) + checkHiveHashForIntervalType("interval 66 day 12 hour 39 minute 23 second 987 millisecond", + 1199697904) + checkHiveHashForIntervalType( + "interval 66 day 12 hour 39 minute 23 second 987 millisecond 123 microsecond", 1199820904) + } + test("hive-hash for array") { // empty array checkHiveHash( From 05887fc3d8d517b416992ee870d0f865b1f9a3d0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 12 Mar 2017 23:16:45 -0700 Subject: [PATCH 016/512] [SPARK-19916][SQL] simplify bad file handling ## What changes were proposed in this pull request? 
We should only have one centre place to try catch the exception for corrupted files. ## How was this patch tested? existing test Author: Wenchen Fan Closes #17253 from cloud-fan/bad-file. --- .../execution/datasources/FileFormat.scala | 4 +- .../execution/datasources/FileScanRDD.scala | 88 +++++++++---------- .../parquet/ParquetFileFormat.scala | 14 --- 3 files changed, 43 insertions(+), 63 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 6784ee243c93..dacf46295352 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -90,7 +90,7 @@ trait FileFormat { * @param options A set of string -> string configuration options. * @return */ - def buildReader( + protected def buildReader( sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, @@ -98,8 +98,6 @@ trait FileFormat { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - // TODO: Remove this default implementation when the other formats have been ported - // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. throw new UnsupportedOperationException(s"buildReader is not supported for $this") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 14f721d6a790..a89d172a911a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import scala.collection.mutable @@ -44,7 +44,7 @@ case class PartitionedFile( filePath: String, start: Long, length: Long, - locations: Array[String] = Array.empty) { + @transient locations: Array[String] = Array.empty) { override def toString: String = { s"path: $filePath, range: $start-${start + length}, partition values: $partitionValues" } @@ -121,6 +121,20 @@ class FileScanRDD( nextElement } + private def readCurrentFile(): Iterator[InternalRow] = { + try { + readFunction(currentFile) + } catch { + case e: FileNotFoundException => + throw new FileNotFoundException( + e.getMessage + "\n" + + "It is possible the underlying files have been updated. " + + "You can explicitly invalidate the cache in Spark by " + + "running 'REFRESH TABLE tableName' command in SQL or " + + "by recreating the Dataset/DataFrame involved.") + } + } + /** Advances to the next file. Returns true if a new non-empty iterator is available. */ private def nextIterator(): Boolean = { updateBytesReadWithFileSize() @@ -130,54 +144,36 @@ class FileScanRDD( // Sets InputFileBlockHolder for the file block's information InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) - try { - if (ignoreCorruptFiles) { - currentIterator = new NextIterator[Object] { - private val internalIter = { - try { - // The readFunction may read files before consuming the iterator. - // E.g., vectorized Parquet reader. 
- readFunction(currentFile) - } catch { - case e @(_: RuntimeException | _: IOException) => - logWarning(s"Skipped the rest content in the corrupted file: $currentFile", e) - Iterator.empty - } - } - - override def getNext(): AnyRef = { - try { - if (internalIter.hasNext) { - internalIter.next() - } else { - finished = true - null - } - } catch { - case e: IOException => - logWarning(s"Skipped the rest content in the corrupted file: $currentFile", e) - finished = true - null + if (ignoreCorruptFiles) { + currentIterator = new NextIterator[Object] { + // The readFunction may read some bytes before consuming the iterator, e.g., + // vectorized Parquet reader. Here we use lazy val to delay the creation of + // iterator so that we will throw exception in `getNext`. + private lazy val internalIter = readCurrentFile() + + override def getNext(): AnyRef = { + try { + if (internalIter.hasNext) { + internalIter.next() + } else { + finished = true + null } + } catch { + // Throw FileNotFoundException even `ignoreCorruptFiles` is true + case e: FileNotFoundException => throw e + case e @ (_: RuntimeException | _: IOException) => + logWarning( + s"Skipped the rest of the content in the corrupted file: $currentFile", e) + finished = true + null } - - override def close(): Unit = {} } - } else { - currentIterator = readFunction(currentFile) + + override def close(): Unit = {} } - } catch { - case e: IOException if ignoreCorruptFiles => - logWarning(s"Skipped the rest content in the corrupted file: $currentFile", e) - currentIterator = Iterator.empty - case e: java.io.FileNotFoundException => - throw new java.io.FileNotFoundException( - e.getMessage + "\n" + - "It is possible the underlying files have been updated. " + - "You can explicitly invalidate the cache in Spark by " + - "running 'REFRESH TABLE tableName' command in SQL or " + - "by recreating the Dataset/DataFrame involved." - ) + } else { + currentIterator = readCurrentFile() } hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 5313c2f3746a..062aa5c8ea62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -283,20 +283,6 @@ class ParquetFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - // For Parquet data source, `buildReader` already handles partition values appending. Here we - // simply delegate to `buildReader`. - buildReader( - sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) - } - - override def buildReader( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, From 72c66dbbb4dacaf5fd77bca58c952f34eba7c147 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 13 Mar 2017 16:30:15 -0700 Subject: [PATCH 017/512] [MINOR][ML] Improve MLWriter overwrite error message ## What changes were proposed in this pull request? 
Give proper syntax for Java and Python in addition to Scala. ## How was this patch tested? Manually. Author: Joseph K. Bradley Closes #17215 from jkbradley/write-err-msg. --- .../src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 09bddcdb810b..a8b80031faf8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -104,8 +104,9 @@ abstract class MLWriter extends BaseReadWrite with Logging { // TODO: Revert back to the original content if save is not successful. fs.delete(qualifiedOutputPath, true) } else { - throw new IOException( - s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + throw new IOException(s"Path $path already exists. To overwrite it, " + + s"please use write.overwrite().save(path) for Scala and use " + + s"write().overwrite().save(path) for Java and Python.") } } saveImpl(path) From 4dc3a8171c31e11aafa85200d3928b1745aa32bd Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 14 Mar 2017 12:06:01 +0800 Subject: [PATCH 018/512] [SPARK-19924][SQL] Handle InvocationTargetException for all Hive Shim ### What changes were proposed in this pull request? Since we are using shim for most Hive metastore APIs, the exceptions thrown by the underlying method of Method.invoke() are wrapped by `InvocationTargetException`. Instead of doing it one by one, we should handle all of them in the `withClient`. If any of them is missing, the error message could looks unfriendly. For example, below is an example for dropping tables. ``` Expected exception org.apache.spark.sql.AnalysisException to be thrown, but java.lang.reflect.InvocationTargetException was thrown. ScalaTestFailureLocation: org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14 at (ExternalCatalogSuite.scala:193) org.scalatest.exceptions.TestFailedException: Expected exception org.apache.spark.sql.AnalysisException to be thrown, but java.lang.reflect.InvocationTargetException was thrown. 
at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:496) at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555) at org.scalatest.Assertions$class.intercept(Assertions.scala:1004) at org.scalatest.FunSuite.intercept(FunSuite.scala:1555) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply$mcV$sp(ExternalCatalogSuite.scala:193) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply(ExternalCatalogSuite.scala:183) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply(ExternalCatalogSuite.scala:183) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(ExternalCatalogSuite.scala:40) at org.scalatest.BeforeAndAfterEach$class.runTest(BeforeAndAfterEach.scala:255) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite.runTest(ExternalCatalogSuite.scala:40) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:31) at org.scalatest.tools.SuiteRunner.run(SuiteRunner.scala:55) at org.scalatest.tools.Runner$$anonfun$doRunRunRunDaDoRunRun$3.apply(Runner.scala:2563) at org.scalatest.tools.Runner$$anonfun$doRunRunRunDaDoRunRun$3.apply(Runner.scala:2557) at scala.collection.immutable.List.foreach(List.scala:381) at 
org.scalatest.tools.Runner$.doRunRunRunDaDoRunRun(Runner.scala:2557) at org.scalatest.tools.Runner$$anonfun$runOptionallyWithPassFailReporter$2.apply(Runner.scala:1044) at org.scalatest.tools.Runner$$anonfun$runOptionallyWithPassFailReporter$2.apply(Runner.scala:1043) at org.scalatest.tools.Runner$.withClassLoaderAndDispatchReporter(Runner.scala:2722) at org.scalatest.tools.Runner$.runOptionallyWithPassFailReporter(Runner.scala:1043) at org.scalatest.tools.Runner$.run(Runner.scala:883) at org.scalatest.tools.Runner.run(Runner.scala) at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.runScalaTest2(ScalaTestRunner.java:138) at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.main(ScalaTestRunner.java:28) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at com.intellij.rt.execution.application.AppMain.main(AppMain.java:147) Caused by: java.lang.reflect.InvocationTargetException at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.spark.sql.hive.client.Shim_v0_14.dropTable(HiveShim.scala:736) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$dropTable$1.apply$mcV$sp(HiveClientImpl.scala:451) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$dropTable$1.apply(HiveClientImpl.scala:451) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$dropTable$1.apply(HiveClientImpl.scala:451) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:287) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:228) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:227) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:270) at org.apache.spark.sql.hive.client.HiveClientImpl.dropTable(HiveClientImpl.scala:450) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$dropTable$1.apply$mcV$sp(HiveExternalCatalog.scala:456) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$dropTable$1.apply(HiveExternalCatalog.scala:454) at org.apache.spark.sql.hive.HiveExternalCatalog$$anonfun$dropTable$1.apply(HiveExternalCatalog.scala:454) at org.apache.spark.sql.hive.HiveExternalCatalog.withClient(HiveExternalCatalog.scala:94) at org.apache.spark.sql.hive.HiveExternalCatalog.dropTable(HiveExternalCatalog.scala:454) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14$$anonfun$apply$mcV$sp$8.apply$mcV$sp(ExternalCatalogSuite.scala:194) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14$$anonfun$apply$mcV$sp$8.apply(ExternalCatalogSuite.scala:194) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14$$anonfun$apply$mcV$sp$8.apply(ExternalCatalogSuite.scala:194) at org.scalatest.Assertions$class.intercept(Assertions.scala:997) ... 57 more Caused by: org.apache.hadoop.hive.ql.metadata.HiveException: NoSuchObjectException(message:db2.unknown_table table not found) at org.apache.hadoop.hive.ql.metadata.Hive.dropTable(Hive.java:1038) ... 
79 more Caused by: NoSuchObjectException(message:db2.unknown_table table not found) at org.apache.hadoop.hive.metastore.HiveMetaStore$HMSHandler.get_table_core(HiveMetaStore.java:1808) at org.apache.hadoop.hive.metastore.HiveMetaStore$HMSHandler.get_table(HiveMetaStore.java:1778) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.hadoop.hive.metastore.RetryingHMSHandler.invoke(RetryingHMSHandler.java:107) at com.sun.proxy.$Proxy10.get_table(Unknown Source) at org.apache.hadoop.hive.metastore.HiveMetaStoreClient.getTable(HiveMetaStoreClient.java:1208) at org.apache.hadoop.hive.ql.metadata.SessionHiveMetaStoreClient.getTable(SessionHiveMetaStoreClient.java:131) at org.apache.hadoop.hive.metastore.HiveMetaStoreClient.dropTable(HiveMetaStoreClient.java:952) at org.apache.hadoop.hive.metastore.HiveMetaStoreClient.dropTable(HiveMetaStoreClient.java:904) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.hadoop.hive.metastore.RetryingMetaStoreClient.invoke(RetryingMetaStoreClient.java:156) at com.sun.proxy.$Proxy11.dropTable(Unknown Source) at org.apache.hadoop.hive.ql.metadata.Hive.dropTable(Hive.java:1035) ... 79 more ``` After unwrapping the exception, the message is like ``` org.apache.hadoop.hive.ql.metadata.HiveException: NoSuchObjectException(message:db2.unknown_table table not found); org.apache.spark.sql.AnalysisException: org.apache.hadoop.hive.ql.metadata.HiveException: NoSuchObjectException(message:db2.unknown_table table not found); at org.apache.spark.sql.hive.HiveExternalCatalog.withClient(HiveExternalCatalog.scala:100) at org.apache.spark.sql.hive.HiveExternalCatalog.dropTable(HiveExternalCatalog.scala:460) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply$mcV$sp(ExternalCatalogSuite.scala:193) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply(ExternalCatalogSuite.scala:183) at org.apache.spark.sql.catalyst.catalog.ExternalCatalogSuite$$anonfun$14.apply(ExternalCatalogSuite.scala:183) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) ... ``` ### How was this patch tested? Covered by the existing test case in `test("drop table when database/table does not exist")` in `ExternalCatalogSuite`. Author: Xiao Li Closes #17265 from gatorsmile/InvocationTargetException. 
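For readers unfamiliar with the reflection behaviour this relies on, a small self-contained sketch (class and method names are made up for illustration, not taken from the Spark code base):

```scala
import java.lang.reflect.InvocationTargetException

object UnwrapSketch {
  class Thrower {
    def boom(): Unit = throw new IllegalStateException("table not found")
  }

  def main(args: Array[String]): Unit = {
    val method = classOf[Thrower].getMethod("boom")
    try {
      // Method.invoke wraps whatever the target throws in InvocationTargetException.
      method.invoke(new Thrower())
    } catch {
      case e: InvocationTargetException =>
        // Unwrapping recovers the original exception and its readable message,
        // which is what the centralized handling in withClient now does.
        println(e.getCause.getMessage) // prints "table not found"
    }
  }
}
```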
--- .../spark/sql/hive/HiveExternalCatalog.scala | 12 ++++++++++-- .../apache/spark/sql/hive/client/HiveShim.scala | 14 +++----------- .../spark/sql/hive/execution/HiveDDLSuite.scala | 13 ++++++------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 78aa2bd2494f..fd633869dde5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.IOException +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util @@ -68,7 +69,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Exceptions thrown by the hive client that we would like to wrap private val clientExceptions = Set( classOf[HiveException].getCanonicalName, - classOf[TException].getCanonicalName) + classOf[TException].getCanonicalName, + classOf[InvocationTargetException].getCanonicalName) /** * Whether this is an exception thrown by the hive client that should be wrapped. @@ -94,7 +96,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat try { body } catch { - case NonFatal(e) if isClientException(e) => + case NonFatal(exception) if isClientException(exception) => + val e = exception match { + // Since we are using shim, the exceptions thrown by the underlying method of + // Method.invoke() are wrapped by InvocationTargetException + case i: InvocationTargetException => i.getCause + case o => o + } throw new AnalysisException( e.getClass.getCanonicalName + ": " + e.getMessage, cause = Some(e)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index c6188fc683e7..153f1673c96f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -733,12 +733,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { deleteData: Boolean, ignoreIfNotExists: Boolean, purge: Boolean): Unit = { - try { - dropTableMethod.invoke(hive, dbName, tableName, deleteData: JBoolean, - ignoreIfNotExists: JBoolean, purge: JBoolean) - } catch { - case e: InvocationTargetException => throw e.getCause() - } + dropTableMethod.invoke(hive, dbName, tableName, deleteData: JBoolean, + ignoreIfNotExists: JBoolean, purge: JBoolean) } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { @@ -824,11 +820,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { val dropOptions = dropOptionsClass.newInstance().asInstanceOf[Object] dropOptionsDeleteData.setBoolean(dropOptions, deleteData) dropOptionsPurge.setBoolean(dropOptions, purge) - try { - dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) - } catch { - case e: InvocationTargetException => throw e.getCause() - } + dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index d29242bb47e3..d752c415c1ed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution import java.io.File -import java.lang.reflect.InvocationTargetException import java.net.URI import org.apache.hadoop.fs.Path @@ -1799,9 +1798,9 @@ class HiveDDLSuite assert(loc.listFiles().length >= 1) checkAnswer(spark.table("t"), Row("1") :: Nil) } else { - val e = intercept[InvocationTargetException] { + val e = intercept[AnalysisException] { spark.sql("INSERT INTO TABLE t SELECT 1") - }.getTargetException.getMessage + }.getMessage assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) } } @@ -1836,14 +1835,14 @@ class HiveDDLSuite checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) } else { - val e = intercept[InvocationTargetException] { + val e = intercept[AnalysisException] { spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") - }.getTargetException.getMessage + }.getMessage assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) - val e1 = intercept[InvocationTargetException] { + val e1 = intercept[AnalysisException] { spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") - }.getTargetException.getMessage + }.getMessage assert(e1.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) } } From 415f9f3423aacc395097e40427364c921a2ed7f1 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 14 Mar 2017 14:19:02 +0800 Subject: [PATCH 019/512] [SPARK-19921][SQL][TEST] Enable end-to-end testing using different Hive metastore versions. ### What changes were proposed in this pull request? To improve the quality of our Spark SQL in different Hive metastore versions, this PR is to enable end-to-end testing using different versions. This PR allows the test cases in sql/hive to pass the existing Hive client to create a SparkSession. - Since Derby does not allow concurrent connections, the pre-built Hive clients use different database from the TestHive's built-in 1.2.1 client. - Since our test cases in sql/hive only can create a single Spark context in the same JVM, the newly created SparkSession share the same spark context with the existing TestHive's corresponding SparkSession. ### How was this patch tested? Fixed the existing test cases. Author: Xiao Li Closes #17260 from gatorsmile/versionSuite. --- .../spark/sql/internal/SharedState.scala | 2 +- .../spark/sql/hive/HiveExternalCatalog.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 69 +++++++++++++-- .../spark/sql/hive/client/VersionsSuite.scala | 85 +++++++++++-------- 4 files changed, 112 insertions(+), 46 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 86129fa87fea..1ef9d52713d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -87,7 +87,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. 
*/ - val externalCatalog: ExternalCatalog = + lazy val externalCatalog: ExternalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index fd633869dde5..33802ae62333 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -62,7 +62,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat /** * A Hive client used to interact with the metastore. */ - val client: HiveClient = { + lazy val client: HiveClient = { HiveUtils.newClientForMetadata(conf, hadoopConf) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 076c40d45932..b63ed76967bd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -24,23 +24,24 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedRelation} -import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.ExternalCatalog import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner} +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.util.{ShutdownHookManager, Utils} // SPARK-3729: Test key required to check for initialization errors with config. 
@@ -58,6 +59,37 @@ object TestHive .set("spark.ui.enabled", "false"))) +case class TestHiveVersion(hiveClient: HiveClient) + extends TestHiveContext(TestHive.sparkContext, hiveClient) + + +private[hive] class TestHiveExternalCatalog( + conf: SparkConf, + hadoopConf: Configuration, + hiveClient: Option[HiveClient] = None) + extends HiveExternalCatalog(conf, hadoopConf) with Logging { + + override lazy val client: HiveClient = + hiveClient.getOrElse { + HiveUtils.newClientForMetadata(conf, hadoopConf) + } +} + + +private[hive] class TestHiveSharedState( + sc: SparkContext, + hiveClient: Option[HiveClient] = None) + extends SharedState(sc) { + + override lazy val externalCatalog: ExternalCatalog = { + new TestHiveExternalCatalog( + sc.conf, + sc.hadoopConfiguration, + hiveClient) + } +} + + /** * A locally running test instance of Spark's Hive execution engine. * @@ -81,6 +113,12 @@ class TestHiveContext( this(new TestHiveSparkSession(HiveUtils.withHiveExternalCatalog(sc), loadTestTables)) } + def this(sc: SparkContext, hiveClient: HiveClient) { + this(new TestHiveSparkSession(HiveUtils.withHiveExternalCatalog(sc), + hiveClient, + loadTestTables = false)) + } + override def newSession(): TestHiveContext = { new TestHiveContext(sparkSession.newSession()) } @@ -115,7 +153,7 @@ class TestHiveContext( */ private[hive] class TestHiveSparkSession( @transient private val sc: SparkContext, - @transient private val existingSharedState: Option[SharedState], + @transient private val existingSharedState: Option[TestHiveSharedState], private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => @@ -126,6 +164,13 @@ private[hive] class TestHiveSparkSession( loadTestTables) } + def this(sc: SparkContext, hiveClient: HiveClient, loadTestTables: Boolean) { + this( + sc, + existingSharedState = Some(new TestHiveSharedState(sc, Some(hiveClient))), + loadTestTables) + } + { // set the metastore temporary configuration val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", @@ -141,8 +186,8 @@ private[hive] class TestHiveSparkSession( assume(sc.conf.get(CATALOG_IMPLEMENTATION) == "hive") @transient - override lazy val sharedState: SharedState = { - existingSharedState.getOrElse(new SharedState(sc)) + override lazy val sharedState: TestHiveSharedState = { + existingSharedState.getOrElse(new TestHiveSharedState(sc)) } @transient @@ -463,6 +508,14 @@ private[hive] class TestHiveSparkSession( FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } + // HDFS root scratch dir requires the write all (733) permission. For each connecting user, + // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with + // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to + // delete it. Later, it will be re-created with the right permission. + val location = new Path(sc.hadoopConfiguration.get(ConfVars.SCRATCHDIR.varname)) + val fs = location.getFileSystem(sc.hadoopConfiguration) + fs.delete(location, true) + // Some tests corrupt this value on purpose, which breaks the RESET call below. 
sessionState.conf.setConfString("fs.defaultFS", new File(".").toURI.toString) // It is important that we RESET first as broken hooks that might have been set could break diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6025f8adbce2..cb1386111035 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,21 +21,20 @@ import java.io.{ByteArrayOutputStream, File, PrintStream} import java.net.URI import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat +import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPermanentFunctionException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} +import org.apache.spark.sql.hive.test.TestHiveVersion import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.StructType import org.apache.spark.tags.ExtendedHiveTest @@ -48,11 +47,31 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} * is not fully tested. */ @ExtendedHiveTest -class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton with Logging { +class VersionsSuite extends SparkFunSuite with Logging { private val clientBuilder = new HiveClientBuilder import clientBuilder.buildClient + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops table `tableName` after calling `f`. 
+ */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + versionSpark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + test("success sanity check") { val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) val db = new CatalogDatabase("default", "desc", new URI("loc"), Map()) @@ -93,6 +112,8 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w private var client: HiveClient = null + private var versionSpark: TestHiveVersion = null + versions.foreach { version => test(s"$version: create client") { client = null @@ -105,6 +126,10 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w hadoopConf.set("datanucleus.schema.autoCreateAll", "true") } client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) + if (versionSpark != null) versionSpark.reset() + versionSpark = TestHiveVersion(client) + assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { @@ -545,22 +570,22 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w test(s"$version: CREATE TABLE AS SELECT") { withTable("tbl") { - spark.sql("CREATE TABLE tbl AS SELECT 1 AS a") - assert(spark.table("tbl").collect().toSeq == Seq(Row(1))) + versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) } } test(s"$version: Delete the temporary staging directory and files after each insert") { withTempDir { tmpDir => withTable("tab") { - spark.sql( + versionSpark.sql( s""" |CREATE TABLE tab(c1 string) |location '${tmpDir.toURI.toString}' """.stripMargin) (1 to 3).map { i => - spark.sql(s"INSERT OVERWRITE TABLE tab SELECT '$i'") + versionSpark.sql(s"INSERT OVERWRITE TABLE tab SELECT '$i'") } def listFiles(path: File): List[String] = { val dir = path.listFiles() @@ -569,7 +594,9 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w folders.flatMap(listFiles) ++: filePaths } // expect 2 files left: `.part-00000-random-uuid.crc` and `part-00000-random-uuid` - assert(listFiles(tmpDir).length == 2) + // 0.12, 0.13, 1.0 and 1.1 also has another two more files ._SUCCESS.crc and _SUCCESS + val metadataFiles = Seq("._SUCCESS.crc", "_SUCCESS") + assert(listFiles(tmpDir).filterNot(metadataFiles.contains).length == 2) } } } @@ -609,7 +636,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w withTable(tableName, tempTableName) { // Creates the external partitioned Avro table to be tested. - sql( + versionSpark.sql( s"""CREATE EXTERNAL TABLE $tableName |PARTITIONED BY (ds STRING) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' @@ -622,7 +649,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w ) // Creates an temporary Avro table used to prepare testing Avro file. - sql( + versionSpark.sql( s"""CREATE EXTERNAL TABLE $tempTableName |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' |STORED AS @@ -634,43 +661,29 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w ) // Generates Avro data. 
- sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") + versionSpark.sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") // Adds generated Avro data as a new partition to the testing table. - sql(s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") + versionSpark.sql( + s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") // The following query fails before SPARK-13709 is fixed. This is because when reading // data from table partitions, Avro deserializer needs the Avro schema, which is defined // in table property "avro.schema.literal". However, we only initializes the deserializer // using partition properties, which doesn't include the wanted property entry. Merging // two sets of properties solves the problem. - checkAnswer( - sql(s"SELECT * FROM $tableName"), - Row(1, Row(2, 2.5D), "foo") - ) + assert(versionSpark.sql(s"SELECT * FROM $tableName").collect() === + Array(Row(1, Row(2, 2.5D), "foo"))) } } } test(s"$version: CTAS for managed data source tables") { withTable("t", "t1") { - import spark.implicits._ - - val tPath = new Path(spark.sessionState.conf.warehousePath, "t") - Seq("1").toDF("a").write.saveAsTable("t") - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - - assert(table.location == makeQualifiedPath(tPath.toString)) - assert(tPath.getFileSystem(spark.sessionState.newHadoopConf()).exists(tPath)) - checkAnswer(spark.table("t"), Row("1") :: Nil) - - val t1Path = new Path(spark.sessionState.conf.warehousePath, "t1") - spark.sql("create table t1 using parquet as select 2 as a") - val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - - assert(table1.location == makeQualifiedPath(t1Path.toString)) - assert(t1Path.getFileSystem(spark.sessionState.newHadoopConf()).exists(t1Path)) - checkAnswer(spark.table("t1"), Row(2) :: Nil) + versionSpark.range(1).write.saveAsTable("t") + assert(versionSpark.table("t").collect() === Array(Row(0))) + versionSpark.sql("create table t1 using parquet as select 2 as a") + assert(versionSpark.table("t1").collect() === Array(Row(2))) } } // TODO: add more tests. From f6314eab4b494bd5b5e9e41c6f582d4f22c0967a Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Tue, 14 Mar 2017 00:50:38 -0700 Subject: [PATCH 020/512] [SPARK-19391][SPARKR][ML] Tweedie GLM API for SparkR ## What changes were proposed in this pull request? Port Tweedie GLM #16344 to SparkR felixcheung yanboliang ## How was this patch tested? new test in SparkR Author: actuaryzhang Closes #16729 from actuaryzhang/sparkRTweedie. --- R/pkg/R/mllib_regression.R | 55 ++++++++++++++++--- .../tests/testthat/test_mllib_regression.R | 38 ++++++++++++- R/pkg/vignettes/sparkr-vignettes.Rmd | 19 ++++++- .../GeneralizedLinearRegressionWrapper.scala | 19 +++++-- 4 files changed, 117 insertions(+), 14 deletions(-) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 648d363f1a25..d59c890f3e5f 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -53,12 +53,23 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' Currently these families are supported: \code{binomial}, \code{gaussian}, -#' \code{Gamma}, and \code{poisson}. +#' \code{Gamma}, \code{poisson} and \code{tweedie}. +#' +#' Note that there are two ways to specify the tweedie family. 
+#' \itemize{ +#' \item Set \code{family = "tweedie"} and specify the var.power and link.power; +#' \item When package \code{statmod} is loaded, the tweedie family is specified using the +#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' } #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param regParam regularization parameter for L2 regularization. +#' @param var.power the power in the variance function of the Tweedie distribution which provides +#' the relationship between the variance and mean of the distribution. Only +#' applicable to the Tweedie family. +#' @param link.power the index in the power link function. Only applicable to the Tweedie family. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -84,14 +95,30 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' # can also read back the saved model and print #' savedModel <- read.ml(path) #' summary(savedModel) +#' +#' # fit tweedie model +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", +#' var.power = 1.2, link.power = 0) +#' summary(model) +#' +#' # use the tweedie family from statmod +#' library(statmod) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0)) +#' summary(model) #' } #' @note spark.glm since 2.0.0 #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0) { + regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { + if (is.character(family)) { - family <- get(family, mode = "function", envir = parent.frame()) + # Handle when family = "tweedie" + if (tolower(family) == "tweedie") { + family <- list(family = "tweedie", link = NULL) + } else { + family <- get(family, mode = "function", envir = parent.frame()) + } } if (is.function(family)) { family <- family() @@ -100,6 +127,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), print(family) stop("'family' not recognized") } + # Handle when family = statmod::tweedie() + if (tolower(family$family) == "tweedie" && !is.null(family$variance)) { + var.power <- log(family$variance(exp(1))) + link.power <- log(family$linkfun(exp(1))) + family <- list(family = "tweedie", link = NULL) + } formula <- paste(deparse(formula), collapse = "") if (!is.null(weightCol) && weightCol == "") { @@ -111,7 +144,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, - tol, as.integer(maxIter), weightCol, regParam) + tol, as.integer(maxIter), weightCol, regParam, + as.double(var.power), as.double(link.power)) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -126,11 +160,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. 
#' Currently these families are supported: \code{binomial}, \code{gaussian}, -#' \code{Gamma}, and \code{poisson}. +#' \code{poisson}, \code{Gamma}, and \code{tweedie}. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param epsilon positive convergence tolerance of iterations. #' @param maxit integer giving the maximal number of IRLS iterations. +#' @param var.power the index of the power variance function in the Tweedie family. +#' @param link.power the index of the power link function in the Tweedie family. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -145,8 +181,10 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @note glm since 1.5.0 #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { - spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, + var.power = 0.0, link.power = 1.0 - var.power) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, + var.power = var.power, link.power = link.power) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). @@ -172,9 +210,10 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), deviance <- callJMethod(jobj, "rDeviance") df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") - aic <- callJMethod(jobj, "rAic") iter <- callJMethod(jobj, "rNumIterations") family <- callJMethod(jobj, "rFamily") + aic <- callJMethod(jobj, "rAic") + if (family == "tweedie" && aic == 0) aic <- NA deviance.resid <- if (is.loaded) { NULL } else { diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 81a5bdc41492..3e9ad7719807 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -77,6 +77,24 @@ test_that("spark.glm and predict", { out <- capture.output(print(summary(model))) expect_true(any(grepl("Dispersion parameter for gamma family", out))) + # tweedie family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15) @@ -233,7 +251,7 @@ test_that("glm and predict", { training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - prediction <- predict(model, training) + 
prediction <- predict(model, training) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") vals <- collect(select(prediction, "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) @@ -249,6 +267,24 @@ test_that("glm and predict", { data = iris, family = poisson(link = identity)), iris)) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # tweedie family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 43c255cff302..a6ff650c33fe 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -672,6 +672,7 @@ gaussian | identity, log, inverse binomial | logit, probit, cloglog (complementary log-log) poisson | log, identity, sqrt gamma | inverse, identity, log +tweedie | power link function There are three ways to specify the `family` argument. @@ -679,7 +680,11 @@ There are three ways to specify the `family` argument. * Family function, e.g. `family = binomial`. -* Result returned by a family function, e.g. `family = poisson(link = log)` +* Result returned by a family function, e.g. `family = poisson(link = log)`. + +* Note that there are two ways to specify the tweedie family: + a) Set `family = "tweedie"` and specify the `var.power` and `link.power` + b) When package `statmod` is loaded, the tweedie family is specified using the family definition therein, i.e., `tweedie()`. For more information regarding the families and their link functions, see the Wikipedia page [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model). @@ -695,6 +700,18 @@ gaussianFitted <- predict(gaussianGLM, carsDF) head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) ``` +The following is the same fit using the tweedie family: +```{r} +tweedieGLM1 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 0.0) +summary(tweedieGLM1) +``` +We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: +```{r} +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", + var.power = 1.2, link.power = 0.0) +summary(tweedieGLM2) +``` + #### Isotonic Regression `spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. 
Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index cbd6cd1c7933..c49416b24018 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -71,7 +71,9 @@ private[r] object GeneralizedLinearRegressionWrapper tol: Double, maxIter: Int, weightCol: String, - regParam: Double): GeneralizedLinearRegressionWrapper = { + regParam: Double, + variancePower: Double, + linkPower: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) @@ -83,13 +85,17 @@ private[r] object GeneralizedLinearRegressionWrapper // assemble and fit the pipeline val glr = new GeneralizedLinearRegression() .setFamily(family) - .setLink(link) .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) - + // set variancePower and linkPower if family is tweedie; otherwise, set link function + if (family.toLowerCase == "tweedie") { + glr.setVariancePower(variancePower).setLinkPower(linkPower) + } else { + glr.setLink(link) + } if (weightCol != null) glr.setWeightCol(weightCol) val pipeline = new Pipeline() @@ -145,7 +151,12 @@ private[r] object GeneralizedLinearRegressionWrapper val rDeviance: Double = summary.deviance val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom - val rAic: Double = summary.aic + val rAic: Double = if (family.toLowerCase == "tweedie" && + !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) { + 0.0 + } else { + summary.aic + } val rNumIterations: Int = summary.numIterations new GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion, From 4ce970d71488c7de6025ef925f75b8b92a5a6a79 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 14 Mar 2017 10:37:10 +0100 Subject: [PATCH 021/512] [SPARK-18874][SQL] First phase: Deferring the correlated predicate pull up to Optimizer phase ## What changes were proposed in this pull request? Currently Analyzer as part of ResolveSubquery, pulls up the correlated predicates to its originating SubqueryExpression. The subquery plan is then transformed to remove the correlated predicates after they are moved up to the outer plan. In this PR, the task of pulling up correlated predicates is deferred to Optimizer. This is the initial work that will allow us to support the form of correlated subqueries that we don't support today. The design document from nsyca can be found in the following link : [DesignDoc](https://docs.google.com/document/d/1QDZ8JwU63RwGFS6KVF54Rjj9ZJyK33d49ZWbjFBaIgU/edit#) The brief description of code changes (hopefully to aid with code review) can be be found in the following link: [CodeChanges](https://docs.google.com/document/d/18mqjhL9V1An-tNta7aVE13HkALRZ5GZ24AATA-Vqqf0/edit#) ## How was this patch tested? The test case PRs were submitted earlier using. 
[16337](https://github.com/apache/spark/pull/16337) [16759](https://github.com/apache/spark/pull/16759) [16841](https://github.com/apache/spark/pull/16841) [16915](https://github.com/apache/spark/pull/16915) [16798](https://github.com/apache/spark/pull/16798) [16712](https://github.com/apache/spark/pull/16712) [16710](https://github.com/apache/spark/pull/16710) [16760](https://github.com/apache/spark/pull/16760) [16802](https://github.com/apache/spark/pull/16802) Author: Dilip Biswal Closes #16954 from dilipbiswal/SPARK-18874. --- .../sql/catalyst/analysis/Analyzer.scala | 314 ++++++++++-------- .../sql/catalyst/analysis/CheckAnalysis.scala | 40 +-- .../sql/catalyst/analysis/TypeCoercion.scala | 130 ++++++-- .../sql/catalyst/expressions/predicates.scala | 43 ++- .../sql/catalyst/expressions/subquery.scala | 256 ++++++++++---- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../sql/catalyst/optimizer/subquery.scala | 159 ++++++++- .../analysis/AnalysisErrorSuite.scala | 11 +- .../analysis/ResolveSubquerySuite.scala | 2 +- .../spark/sql/catalyst/plans/PlanTest.scala | 2 - .../apache/spark/sql/execution/subquery.scala | 3 - .../invalid-correlation.sql.out | 4 +- .../org/apache/spark/sql/SubquerySuite.scala | 7 +- 13 files changed, 675 insertions(+), 300 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 93666f14958e..a3764d8c843d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -21,12 +21,13 @@ import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} @@ -162,6 +163,8 @@ class Analyzer( FixNullability), Batch("ResolveTimeZone", Once, ResolveTimeZone), + Batch("Subquery", Once, + UpdateOuterReferences), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -710,13 +713,72 @@ class Analyzer( } transformUp { case other => other transformExpressions { case a: Attribute => - attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier) + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } newRight } } + private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + } + + /** + * The outer plan may have been de-duplicated and the function below updates the + * outer references to refer to the de-duplicated attributes. 
+ * + * For example (SQL): + * {{{ + * SELECT * FROM t1 + * INTERSECT + * SELECT * FROM t1 + * WHERE EXISTS (SELECT 1 + * FROM t2 + * WHERE t1.c1 = t2.c1) + * }}} + * Plan before resolveReference rule. + * 'Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- 'Project [*] + * +- Filter exists#257 [c1#245] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#245) = c1#251) + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#245,c2#246] parquet + * Plan after the resolveReference rule. + * Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- Project [c1#259, c2#260] + * +- Filter exists#257 [c1#259] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#259) = c1#251) => Updated + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated. + */ + private def dedupOuterReferencesInSubquery( + plan: LogicalPlan, + attrMap: AttributeMap[Attribute]): LogicalPlan = { + plan transformDown { case currentFragment => + currentFragment transformExpressions { + case OuterReference(a: Attribute) => + OuterReference(dedupAttr(a, attrMap)) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap)) + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -1132,28 +1194,21 @@ class Analyzer( } /** - * Pull out all (outer) correlated predicates from a given subquery. This method removes the - * correlated predicates from subquery [[Filter]]s and adds the references of these predicates - * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to - * be able to evaluate the predicates at the top level. - * - * This method returns the rewritten subquery and correlated predicates. + * Validates to make sure the outer references appearing inside the subquery + * are legal. This function also returns the list of expressions + * that contain outer references. These outer references would be kept as children + * of subquery expressions by the caller of this function. */ - private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { - val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { + val outerReferences = ArrayBuffer.empty[Expression] // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (p.collectFirst(predicateMap).nonEmpty) { + if (hasOuterReferences(p)) { failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } - // Helper function for locating outer references. - def containsOuter(e: Expression): Boolean = { - e.find(_.isInstanceOf[OuterReference]).isDefined - } - // Make sure a plan's expressions do not contain outer references def failOnOuterReference(p: LogicalPlan): Unit = { if (p.expressions.exists(containsOuter)) { @@ -1194,20 +1249,11 @@ class Analyzer( } } - /** Determine which correlated predicate references are missing from this plan. 
*/ - def missingReferences(p: LogicalPlan): AttributeSet = { - val localPredicateReferences = p.collect(predicateMap) - .flatten - .map(_.references) - .reduceOption(_ ++ _) - .getOrElse(AttributeSet.empty) - localPredicateReferences -- p.outputSet - } - var foundNonEqualCorrelatedPred : Boolean = false - // Simplify the predicates before pulling them out. - val transformed = BooleanSimplification(sub) transformUp { + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: @@ -1229,80 +1275,48 @@ class Analyzer( // Category 1: // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case p: BroadcastHint => - p - case p: Distinct => - p - case p: LeafNode => - p - case p: Repartition => - p - case p: SubqueryAlias => - p + case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => // Category 2: // These operators can be anywhere in a correlated subquery. // so long as they do not host outer references in the operators. - case p: Sort => - failOnOuterReference(p) - p - case p: RepartitionByExpression => - failOnOuterReference(p) - p + case s: Sort => + failOnOuterReference(s) + case r: RepartitionByExpression => + failOnOuterReference(r) // Category 3: // Filter is one of the two operators allowed to host correlated expressions. // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f @ Filter(cond, child) => + case f: Filter => // Find all predicates with an outer reference. - val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) // Find any non-equality correlated predicates foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { case _: EqualTo | _: EqualNullSafe => false case _ => true } - - // Rewrite the filter without the correlated predicates if any. - correlated match { - case Nil => f - case xs if local.nonEmpty => - val newFilter = Filter(local.reduce(And), child) - predicateMap += newFilter -> xs - newFilter - case xs => - predicateMap += child -> xs - child - } + // The aggregate expressions are treated in a special way by getOuterReferences. If the + // aggregate expression contains only outer reference attributes then the entire aggregate + // expression is isolated as an OuterReference. + // i.e min(OuterReference(b)) => OuterReference(min(b)) + outerReferences ++= getOuterReferences(correlated) // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. - case p @ Project(expressions, child) => + case p: Project => failOnOuterReference(p) - val referencesToAdd = missingReferences(p) - if (referencesToAdd.nonEmpty) { - Project(expressions ++ referencesToAdd, child) - } else { - p - } - // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains // only equality correlated predicates. // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. 
- case a @ Aggregate(grouping, expressions, child) => + case a: Aggregate => failOnOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - val referencesToAdd = missingReferences(a) - if (referencesToAdd.nonEmpty) { - Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) - } else { - a - } - // Join can host correlated expressions. case j @ Join(left, right, joinType, _) => joinType match { @@ -1332,7 +1346,6 @@ class Analyzer( case _ => failOnOuterReferenceInSubTree(j) } - j // Generator with join=true, i.e., expressed with // LATERAL VIEW [OUTER], similar to inner join, @@ -1340,9 +1353,8 @@ class Analyzer( // but must not host any outer references. // Note: // Generator with join=false is treated as Category 4. - case p @ Generate(generator, true, _, _, _, _) => - failOnOuterReference(p) - p + case g: Generate if g.join => + failOnOuterReference(g) // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only @@ -1350,54 +1362,17 @@ class Analyzer( // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - p } - (transformed, predicateMap.values.flatten.toSeq) + outerReferences } /** - * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same - * attributes. - */ - private def rewriteSubQuery( - sub: LogicalPlan, - outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { - // Pull out the tagged predicates and rewrite the subquery in the process. - val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) - - // Make sure the inner and the outer query attributes do not collide. - val outputSet = outer.map(_.outputSet).reduce(_ ++ _) - val duplicates = basePlan.outputSet.intersect(outputSet) - val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = basePlan.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, basePlan) - val aliasedConditions = baseConditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (basePlan, baseConditions) - } - // Remove outer references from the correlated predicates. We wait with extracting - // these until collisions between the inner and outer query attributes have been - // solved. - val conditions = deDuplicatedConditions.map(_.transform { - case OuterReference(ref) => ref - }) - (plan, conditions) - } - - /** - * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method + * Resolves the subquery. The subquery is resolved using its outer plans. This method * will resolve the subquery by alternating between the regular analyzer and by applying the * resolveOuterReferences rule. * - * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved. + * Outer references from the correlated predicates are updated as children of + * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, @@ -1420,7 +1395,8 @@ class Analyzer( } } while (!current.resolved && !current.fastEquals(previous)) - // Step 2: Pull out the predicates if the plan is resolved. + // Step 2: If the subquery plan is fully resolved, pull the outer references and record + // them as children of SubqueryExpression. 
if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only // needed for Scalar and IN subqueries. @@ -1428,34 +1404,37 @@ class Analyzer( failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + s"does not match the required number of columns ($requiredColumns)") } - // Pullout predicates and construct a new plan. - f.tupled(rewriteSubQuery(current, plans)) + // Validate the outer reference and record the outer references as children of + // subquery expression. + f(current, checkAndGetOuterReferences(current)) } else { e.withNewPlan(current) } } /** - * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS - * expressions into PredicateSubquery expression once the are resolved. + * Resolves the subquery. Apart of resolving the subquery and outer references (if any) + * in the subquery plan, the children of subquery expression are updated to record the + * outer references. This is needed to make sure + * (1) The column(s) referred from the outer query are not pruned from the plan during + * optimization. + * (2) Any aggregate expression(s) that reference outer attributes are pushed down to + * outer plan to get evaluated. */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) - case e @ Exists(sub, exprId) => - resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) - case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => + case e @ Exists(sub, _, exprId) if !sub.resolved => + resolveSubQuery(e, plans)(Exists(_, _, exprId)) + case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => // Get the left hand side expressions. - val expressions = e match { + val expressions = value match { case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => - // Construct the IN conditions. - val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled) - PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId) - } + val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + In(value, Seq(expr)) } } @@ -2353,6 +2332,11 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. + case e: ListQuery => e.withNewPlan(apply(e.plan)) } } } @@ -2533,3 +2517,67 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { CreateNamedStruct(children.toList) } } + +/** + * The aggregate expressions from subquery referencing outer query block are pushed + * down to the outer query block for evaluation. This rule below updates such outer references + * as AttributeReference referring attributes from the parent/outer query block. 
+ * + * For example (SQL): + * {{{ + * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b)) + * }}} + * Plan before the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < min(outer(b#227))) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + * Plan after the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < outer(min(b#227)#249)) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + */ +object UpdateOuterReferences extends Rule[LogicalPlan] { + private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } + + private def updateOuterReferenceInSubquery( + plan: LogicalPlan, + refExprs: Seq[Expression]): LogicalPlan = { + plan transformAllExpressions { case e => + val outerAlias = + refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) + outerAlias match { + case Some(a: Alias) => OuterReference(a.toAttribute) + case _ => e + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case f @ Filter(_, a: Aggregate) if f.resolved => + f transformExpressions { + case s: SubqueryExpression if s.children.nonEmpty => + // Collect the aliases from output of aggregate. + val outerAliases = a.aggregateExpressions collect { case a: Alias => a } + // Update the subquery plan to record the OuterReference to point to outer query plan. + s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases)) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d32fbeb4e91e..da0c6b098f5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -133,10 +134,8 @@ trait CheckAnalysis extends PredicateHelper { if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") - } else if (conditions.nonEmpty) { - // Collect the columns from the subquery for further checking. - var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) - + } + else if (conditions.nonEmpty) { def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates containing exactly one aggregate expression. 
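To see how the analyzer changes in this patch surface to users, here is a minimal sketch. It assumes an active `SparkSession` named `spark` and two pre-registered tables `l(a, b)` and `r(c, d)` with comparable column types, matching the names used in the Scaladoc examples above; none of these identifiers are introduced by the patch itself.

```scala
// Correlated EXISTS with an aggregate referenced from the HAVING clause,
// mirroring the query in the UpdateOuterReferences scaladoc.
val havingExists = spark.sql(
  """SELECT l.a
    |FROM l
    |GROUP BY 1
    |HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b))""".stripMargin)

// Correlated IN predicate; the outer reference l.b is now recorded as a child of
// the ListQuery expression instead of being pulled up during analysis.
val correlatedIn = spark.sql(
  "SELECT * FROM l WHERE l.a IN (SELECT r.c FROM r WHERE l.b < r.d)")

// The analyzed plans should still contain Exists/ListQuery expressions carrying
// their outer references; the rewrite into semi/anti joins now shows up only in
// the optimized plan.
havingExists.explain(true)
correlatedIn.explain(true)
```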
@@ -152,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper { // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns // are not part of the correlated columns. val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) val correlatedCols = AttributeSet(subqueryColumns) val invalidCols = groupByCols -- correlatedCols // GROUP BY columns must be a subset of columns in the predicates @@ -167,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper { // For projects, do the necessary mapping and skip to its child. def cleanQuery(p: LogicalPlan): LogicalPlan = p match { case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => - // SPARK-18814: Map any aliases to their AttributeReference children - // for the checking in the Aggregate operators below this Project. - subqueryColumns = subqueryColumns.map { - xs => p.projectList.collectFirst { - case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => - child - }.getOrElse(xs) - } - - cleanQuery(p.child) + case p: Project => cleanQuery(p.child) case child => child } @@ -211,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case Filter(condition, _) => - splitConjunctivePredicates(condition).foreach { - case _: PredicateSubquery | Not(_: PredicateSubquery) => - case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => - failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + - s" conditions: $e") - case e => - } + case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => + failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + + s"conditions: $condition") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( @@ -306,8 +293,11 @@ trait CheckAnalysis extends PredicateHelper { s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") } - case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => - failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => + p match { + case _: Filter => // Ok + case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + } case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2c00957bd6af..768897dc0713 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -108,6 +108,28 @@ object TypeCoercion { case _ => None } + /** + * This function determines the target type of a comparison operator when one operand + * is a String and the other is not. It also handles when one op is a Date and the + * other is a Timestamp by making the target type to be String. 
+ */ + val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + // We should cast all relative timestamp/date/string comparison into string comparisons + // This behaves as a user would expect because timestamp strings sort lexicographically. + // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true + case (StringType, DateType) => Some(StringType) + case (DateType, StringType) => Some(StringType) + case (StringType, TimestampType) => Some(StringType) + case (TimestampType, StringType) => Some(StringType) + case (TimestampType, DateType) => Some(StringType) + case (DateType, TimestampType) => Some(StringType) + case (StringType, NullType) => Some(StringType) + case (NullType, StringType) => Some(StringType) + case (l: StringType, r: AtomicType) if r != StringType => Some(r) + case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l, r) => None + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -305,6 +327,14 @@ object TypeCoercion { * Promotes strings that appear in arithmetic expressions. */ object PromoteStrings extends Rule[LogicalPlan] { + private def castExpr(expr: Expression, targetType: DataType): Expression = { + (expr.dataType, targetType) match { + case (NullType, dt) => Literal.create(null, targetType) + case (l, dt) if (l != dt) => Cast(expr, targetType) + case _ => expr + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -321,37 +351,10 @@ object TypeCoercion { case p @ Equality(left @ TimestampType(), right @ StringType()) => p.makeCopy(Array(left, Cast(right, TimestampType))) - // We should cast all relative timestamp/date/string comparison into string comparisons - // This behaves as a user would expect because timestamp strings sort lexicographically. - // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true - case p @ BinaryComparison(left @ StringType(), right @ DateType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - - // Comparisons between dates and timestamps. - case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - - // Checking NullType - case p @ BinaryComparison(left @ StringType(), right @ NullType()) => - p.makeCopy(Array(left, Literal.create(null, StringType))) - case p @ BinaryComparison(left @ NullType(), right @ StringType()) => - p.makeCopy(Array(Literal.create(null, StringType), right)) - - // When compare string with atomic type, case string to that type. 
- case p @ BinaryComparison(left @ StringType(), right @ AtomicType()) - if right.dataType != StringType => - p.makeCopy(Array(Cast(left, right.dataType), right)) - case p @ BinaryComparison(left @ AtomicType(), right @ StringType()) - if left.dataType != StringType => - p.makeCopy(Array(left, Cast(right, left.dataType))) + case p @ BinaryComparison(left, right) + if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) @@ -365,17 +368,72 @@ object TypeCoercion { } /** - * Convert the value and in list expressions to the common operator type - * by looking at all the argument types and finding the closest one that - * all the arguments can be cast to. When no common operator type is found - * the original expression will be returned and an Analysis Exception will - * be raised at type checking phase. + * Handles type coercion for both IN expression with subquery and IN + * expressions without subquery. + * 1. In the first case, find the common type by comparing the left hand side (LHS) + * expression types against corresponding right hand side (RHS) expression derived + * from the subquery expression's plan output. Inject appropriate casts in the + * LHS and RHS side of IN expression. + * + * 2. In the second case, convert the value and in list expressions to the + * common operator type by looking at all the argument types and finding + * the closest one that all the arguments can be cast to. When no common + * operator type is found the original expression will be returned and an + * Analysis Exception will be raised at the type checking phase. */ object InConversion extends Rule[LogicalPlan] { + private def flattenExpr(expr: Expression): Seq[Expression] = { + expr match { + // Multi columns in IN clause is represented as a CreateNamedStruct. + // flatten the named struct to get the list of expressions. + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Handle type casting required between value expression and subquery output + // in IN subquery. + case i @ In(a, Seq(ListQuery(sub, children, exprId))) + if !i.resolved && flattenExpr(a).length == sub.output.length => + // LHS is the value expression of IN subquery. + val lhs = flattenExpr(a) + + // RHS is the subquery output. + val rhs = sub.output + + val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => + findCommonTypeForBinaryComparison(l.dataType, r.dataType) + .orElse(findTightestCommonType(l.dataType, r.dataType)) + } + + // The number of columns/expressions must match between LHS and RHS of an + // IN subquery expression. + if (commonTypes.length == lhs.length) { + val castedRhs = rhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, _) => e + } + val castedLhs = lhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Cast(e, dt) + case (e, _) => e + } + + // Before constructing the In expression, wrap the multi values in LHS + // in a CreatedNamedStruct. 
+ val newLhs = castedLhs match { + case Seq(lhs) => lhs + case _ => CreateStruct(castedLhs) + } + + In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId))) + } else { + i + } + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac56ff13fa5b..e5d1a1e2996c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -123,19 +123,44 @@ case class Not(child: Expression) */ @ExpressionDescription( usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.") -case class In(value: Expression, list: Seq[Expression]) extends Predicate - with ImplicitCastInputTypes { +case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") + override def checkInputDataTypes(): TypeCheckResult = { + list match { + case ListQuery(sub, _, _) :: Nil => + val valExprs = value match { + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } - override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } - override def checkInputDataTypes(): TypeCheckResult = { - if (list.exists(l => l.dataType != value.dataType)) { - TypeCheckResult.TypeCheckFailure( - "Arguments must be same type") - } else { - TypeCheckResult.TypeCheckSuccess + if (mismatchedColumns.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. 
+ """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } + case _ => + if (list.exists(l => l.dataType != value.dataType)) { + TypeCheckResult.TypeCheckFailure("Arguments must be same type") + } else { + TypeCheckResult.TypeCheckSuccess + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e2e7d98e3345..ad11700fa28d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.types._ /** @@ -40,19 +43,184 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { /** * A base interface for expressions that contain a [[LogicalPlan]]. */ -abstract class SubqueryExpression extends PlanExpression[LogicalPlan] { +abstract class SubqueryExpression( + plan: LogicalPlan, + children: Seq[Expression], + exprId: ExprId) extends PlanExpression[LogicalPlan] { + + override lazy val resolved: Boolean = childrenResolved && plan.resolved + override lazy val references: AttributeSet = + if (plan.resolved) super.references -- plan.outputSet else super.references override def withNewPlan(plan: LogicalPlan): SubqueryExpression + override def semanticEquals(o: Expression): Boolean = o match { + case p: SubqueryExpression => + this.getClass.getName.equals(p.getClass.getName) && plan.sameResult(p.plan) && + children.length == p.children.length && + children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) + case _ => false + } } object SubqueryExpression { + /** + * Returns true when an expression contains an IN or EXISTS subquery and false otherwise. + */ + def hasInOrExistsSubquery(e: Expression): Boolean = { + e.find { + case _: ListQuery | _: Exists => true + case _ => false + }.isDefined + } + + /** + * Returns true when an expression contains a subquery that has outer reference(s). The outer + * reference attributes are kept as children of subquery expression by + * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]] + */ def hasCorrelatedSubquery(e: Expression): Boolean = { e.find { - case e: SubqueryExpression if e.children.nonEmpty => true + case s: SubqueryExpression => s.children.nonEmpty case _ => false }.isDefined } } +object SubExprUtils extends PredicateHelper { + /** + * Returns true when an expression contains correlated predicates i.e outer references and + * returns false otherwise. + */ + def containsOuter(e: Expression): Boolean = { + e.find(_.isInstanceOf[OuterReference]).isDefined + } + + /** + * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could + * turn the null-aware predicate into not-null-aware predicate. 
+ */ + def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { + splitConjunctivePredicates(condition).exists { + case _: Exists | Not(_: Exists) | In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => + false + case e => e.find { x => + x.isInstanceOf[Not] && e.find { + case In(_, Seq(_: ListQuery)) => true + case _ => false + }.isDefined + }.isDefined + } + + } + + /** + * Returns an expression after removing the OuterReference shell. + */ + def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r } + + /** + * Returns the list of expressions after removing the OuterReference shell from each of + * the expression. + */ + def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference) + + /** + * Returns the logical plan after removing the OuterReference shell from all the expressions + * of the input logical plan. + */ + def stripOuterReferences(p: LogicalPlan): LogicalPlan = { + p.transformAllExpressions { + case OuterReference(a) => a + } + } + + /** + * Given a logical plan, returns TRUE if it has an outer reference and false otherwise. + */ + def hasOuterReferences(plan: LogicalPlan): Boolean = { + plan.find { + case f: Filter => containsOuter(f.condition) + case other => false + }.isDefined + } + + /** + * Given a list of expressions, returns the expressions which have outer references. Aggregate + * expressions are treated in a special way. If the children of aggregate expression contains an + * outer reference, then the entire aggregate expression is marked as an outer reference. + * Example (SQL): + * {{{ + * SELECT a FROM l GROUP by 1 HAVING EXISTS (SELECT 1 FROM r WHERE d < min(b)) + * }}} + * In the above case, we want to mark the entire min(b) as an outer reference + * OuterReference(min(b)) instead of min(OuterReference(b)). + * TODO: Currently we don't allow deep correlation. Also, we don't allow mixing of + * outer references and local references under an aggregate expression. + * For example (SQL): + * {{{ + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a) + max(p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + sq.c) > 1)) + * }}} + * The code below needs to change when we support the above cases. + */ + def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = { + val outerExpressions = ArrayBuffer.empty[Expression] + conditions foreach { expr => + expr transformDown { + case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => + val newExpr = stripOuterReference(a) + outerExpressions += newExpr + newExpr + case OuterReference(e) => + outerExpressions += e + e + } + } + outerExpressions + } + + /** + * Returns all the expressions that have outer references from a logical plan. Currently only + * Filter operator can host outer references. + */ + def getOuterReferences(plan: LogicalPlan): Seq[Expression] = { + val conditions = plan.collect { case Filter(cond, _) => cond } + getOuterReferences(conditions) + } + + /** + * Returns the correlated predicates from a logical plan. The OuterReference wrapper + * is removed before returning the predicate to the caller. 
+ */ + def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = { + val conditions = plan.collect { case Filter(cond, _) => cond } + conditions.flatMap { e => + val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter) + stripOuterReferences(correlated) match { + case Nil => None + case xs => xs + } + } + } +} + /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. @@ -63,14 +231,8 @@ case class ScalarSubquery( plan: LogicalPlan, children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved: Boolean = childrenResolved && plan.resolved - override lazy val references: AttributeSet = { - if (plan.resolved) super.references -- plan.outputSet - else super.references - } + extends SubqueryExpression(plan, children, exprId) with Unevaluable { override def dataType: DataType = plan.schema.fields.head.dataType - override def foldable: Boolean = false override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" @@ -79,59 +241,12 @@ case class ScalarSubquery( object ScalarSubquery { def hasCorrelatedScalarSubquery(e: Expression): Boolean = { e.find { - case e: ScalarSubquery if e.children.nonEmpty => true + case s: ScalarSubquery => s.children.nonEmpty case _ => false }.isDefined } } -/** - * A predicate subquery checks the existence of a value in a sub-query. We currently only allow - * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will - * be rewritten into a left semi/anti join during analysis. - */ -case class PredicateSubquery( - plan: LogicalPlan, - children: Seq[Expression] = Seq.empty, - nullAware: Boolean = false, - exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = childrenResolved && plan.resolved - override lazy val references: AttributeSet = super.references -- plan.outputSet - override def nullable: Boolean = nullAware - override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan) - override def semanticEquals(o: Expression): Boolean = o match { - case p: PredicateSubquery => - plan.sameResult(p.plan) && nullAware == p.nullAware && - children.length == p.children.length && - children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) - case _ => false - } - override def toString: String = s"predicate-subquery#${exprId.id} $conditionString" -} - -object PredicateSubquery { - def hasPredicateSubquery(e: Expression): Boolean = { - e.find { - case _: PredicateSubquery | _: ListQuery | _: Exists => true - case _ => false - }.isDefined - } - - /** - * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could - * turn the null-aware predicate into not-null-aware predicate. - */ - def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { - e.find{ x => - x.isInstanceOf[Not] && e.find { - case p: PredicateSubquery => p.nullAware - case _ => false - }.isDefined - }.isDefined - } -} - /** * A [[ListQuery]] expression defines the query which we want to search in an IN subquery * expression. It should and can only be used in conjunction with an IN expression. 
@@ -144,18 +259,20 @@ object PredicateSubquery { * FROM b) * }}} */ -case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty - override def dataType: DataType = ArrayType(NullType) +case class ListQuery( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Unevaluable { + override def dataType: DataType = plan.schema.fields.head.dataType override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) - override def toString: String = s"list#${exprId.id}" + override def toString: String = s"list#${exprId.id} $conditionString" } /** * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. + * * For example (SQL): * {{{ * SELECT * @@ -165,11 +282,12 @@ case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExpr * WHERE b.id = a.id) * }}} */ -case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty +case class Exists( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) - override def toString: String = s"exists#${exprId.id}" + override def toString: String = s"exists#${exprId.id} $conditionString" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index caafa1c134cd..e9dbded3d4d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,6 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // since the other rules might make two separate Unions operators adjacent. 
Batch("Union", Once, CombineUnions) :: + Batch("Pullup Correlated Expressions", Once, + PullupCorrelatedPredicates) :: Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, @@ -885,7 +887,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = { val attributes = plan.outputSet val matched = condition.find { - case PredicateSubquery(p, _, _, _) => p.outputSet.intersect(attributes).nonEmpty + case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty case _ => false } matched.isEmpty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index fb7ce6aecea5..ba3fd1d5f802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -41,10 +42,17 @@ import org.apache.spark.sql.types._ * condition. */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { + private def getValueExpression(e: Expression): Seq[Expression] = { + e match { + case cns : CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Filter(condition, child) => val (withSubquery, withoutSubquery) = - splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery) + splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery) // Construct the pruned filter condition. val newFilter: LogicalPlan = withoutSubquery match { @@ -54,20 +62,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. withSubquery.foldLeft(newFilter) { - case (p, PredicateSubquery(sub, conditions, _, _)) => + case (p, Exists(sub, conditions, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, false, _))) => + case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftAnti, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, true, _))) => + case (p, In(value, Seq(ListQuery(sub, conditions, _)))) => + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + Join(outerPlan, sub, LeftSemi, joinCond) + case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. 
- val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: // (a1,b1,...) = (a2,b2,...) @@ -83,11 +96,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } /** - * Given a predicate expression and an input plan, it rewrites - * any embedded existential sub-query into an existential join. - * It returns the rewritten expression together with the updated plan. - * Currently, it does not support null-aware joins. Embedded NOT IN predicates - * are blocked in the Analyzer. + * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query + * into an existential join. It returns the rewritten expression together with the updated plan. + * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in + * the Analyzer. */ private def rewriteExistentialExpr( exprs: Seq[Expression], @@ -95,17 +107,138 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { var newPlan = plan val newExprs = exprs.map { e => e transformUp { - case PredicateSubquery(sub, conditions, nullAware, _) => - // TODO: support null-aware join + case Exists(sub, conditions, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists - } + case In(value, Seq(ListQuery(sub, conditions, _))) => + val exists = AttributeReference("exists", BooleanType, nullable = false)() + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val newConditions = (inConditions ++ conditions).reduceLeftOption(And) + newPlan = Join(newPlan, sub, ExistenceJoin(exists), newConditions) + exists + } } (newExprs.reduceOption(And), newPlan) } } + /** + * Pull out all (outer) correlated predicates from a given subquery. This method removes the + * correlated predicates from subquery [[Filter]]s and adds the references of these predicates + * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to + * be able to evaluate the predicates at the top level. + * + * TODO: Look to merge this rule with RewritePredicateSubquery. + */ +object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper { + /** + * Returns the correlated predicates and a updated plan that removes the outer references. + */ + private def pullOutCorrelatedPredicates( + sub: LogicalPlan, + outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + + /** Determine which correlated predicate references are missing from this plan. */ + def missingReferences(p: LogicalPlan): AttributeSet = { + val localPredicateReferences = p.collect(predicateMap) + .flatten + .map(_.references) + .reduceOption(_ ++ _) + .getOrElse(AttributeSet.empty) + localPredicateReferences -- p.outputSet + } + + // Simplify the predicates before pulling them out. + val transformed = BooleanSimplification(sub) transformUp { + case f @ Filter(cond, child) => + val (correlated, local) = + splitConjunctivePredicates(cond).partition(containsOuter) + + // Rewrite the filter without the correlated predicates if any. 
+ correlated match { + case Nil => f + case xs if local.nonEmpty => + val newFilter = Filter(local.reduce(And), child) + predicateMap += newFilter -> xs + newFilter + case xs => + predicateMap += child -> xs + child + } + case p @ Project(expressions, child) => + val referencesToAdd = missingReferences(p) + if (referencesToAdd.nonEmpty) { + Project(expressions ++ referencesToAdd, child) + } else { + p + } + case a @ Aggregate(grouping, expressions, child) => + val referencesToAdd = missingReferences(a) + if (referencesToAdd.nonEmpty) { + Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) + } else { + a + } + case p => + p + } + + // Make sure the inner and the outer query attributes do not collide. + // In case of a collision, change the subquery plan's output to use + // different attribute by creating alias(s). + val baseConditions = predicateMap.values.flatten.toSeq + val (newPlan, newCond) = if (outer.nonEmpty) { + val outputSet = outer.map(_.outputSet).reduce(_ ++ _) + val duplicates = transformed.outputSet.intersect(outputSet) + val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = transformed.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val aliasedProjection = Project(aliasedExpressions, transformed) + val aliasedConditions = baseConditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (transformed, baseConditions) + } + (plan, stripOuterReferences(deDuplicatedConditions)) + } else { + (transformed, stripOuterReferences(baseConditions)) + } + (newPlan, newCond) + } + + private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = { + plan transformExpressions { + case ScalarSubquery(sub, children, exprId) if children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ScalarSubquery(newPlan, newCond, exprId) + case Exists(sub, children, exprId) if children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + Exists(newPlan, newCond, exprId) + case ListQuery(sub, _, exprId) => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ListQuery(newPlan, newCond, exprId) + } + } + + /** + * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case f @ Filter(_, a: Aggregate) => + rewriteSubQueries(f, Seq(a, a.child)) + // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. + case q: UnaryNode => + rewriteSubQueries(q, q.children) + } +} /** * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. 
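To see what the two rules above do together, here is a small end-to-end sketch. It is not part of the patch: it uses only the public SQL API and assumes a local `SparkSession` plus two hypothetical views `l(a, b)` and `r(c, d)`. `PullupCorrelatedPredicates` should lift the correlated predicate `r.d = l.b` out of the subquery, and `RewritePredicateSubquery` should then turn the `IN` predicate into a LEFT SEMI join whose condition combines the value test with the pulled-up predicate.

```scala
import org.apache.spark.sql.SparkSession

object InSubqueryRewriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("in-subquery-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical data: l(a, b) and r(c, d).
    Seq((1, 10), (2, 20), (3, 30)).toDF("a", "b").createOrReplaceTempView("l")
    Seq((1, 10), (5, 20)).toDF("c", "d").createOrReplaceTempView("r")

    // An IN predicate whose subquery carries an outer reference (l.b).
    val df = spark.sql("SELECT * FROM l WHERE a IN (SELECT c FROM r WHERE r.d = l.b)")

    // The optimized plan should contain a LeftSemi join whose condition includes both
    // the value test (a = c) and the pulled-up correlated predicate (d = b).
    df.explain(true)
    df.show()

    spark.stop()
  }
}
```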
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index c5e877d12811..d2ebca5a83dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -530,7 +530,7 @@ class AnalysisErrorSuite extends AnalysisTest { Exists( Join( LocalRelation(b), - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LeftOuter, Option(EqualTo(b, c)))), LocalRelation(a)) @@ -539,7 +539,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Filter( Exists( Join( - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LocalRelation(b), RightOuter, Option(EqualTo(b, c)))), @@ -547,14 +547,15 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) val plan3 = Filter( - Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))), + Exists(Union(LocalRelation(b), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)))), LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) val plan4 = Filter( Exists( Limit(1, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) ), LocalRelation(a)) assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) @@ -562,7 +563,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 4aafb2b83fb6..55693121431a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(OuterReference(a)), t2)))), t1) + val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.ResolveSubquery(expr) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e9b7a0c6ad67..5eb31413ad70 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -43,8 +43,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { e.copy(exprId = ExprId(0)) case l: ListQuery => l.copy(exprId = ExprId(0)) - case p: PredicateSubquery => - p.copy(exprId = ExprId(0)) case a: 
AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 730ca27f82ba..58be2d1da281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -144,9 +144,6 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { ScalarSubquery( SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan), subquery.exprId) - case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) => - val executedPlan = new QueryExecution(sparkSession, query).executedPlan - InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId) } } } diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 50ae01e181bc..f7bbb35aad6c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -46,7 +46,7 @@ and t2b = (select max(avg) struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; +grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.; -- !query 4 @@ -63,4 +63,4 @@ where t1a in (select min(t2a) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)]; +resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 25dbecb5894e..6f1cd49c08ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -622,7 +622,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") { checkAnswer( - sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"), + sql( + """ + |select l.b, (select (r.c + count(*)) is null + |from r + |where l.a = r.c group by r.c) from l + """.stripMargin), Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) } From 0ee38a39e43dd7ad9d50457e446ae36f64621a1b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 Mar 2017 19:02:30 +0800 Subject: [PATCH 022/512] [SPARK-19944][SQL] Move SQLConf from sql/core to sql/catalyst ## What changes were proposed in this pull request? This patch moves SQLConf from sql/core to sql/catalyst. 
To minimize the changes, the patch used type alias to still keep CatalystConf (as a type alias) and SimpleCatalystConf (as a concrete class that extends SQLConf). Motivation for the change is that it is pretty weird to have SQLConf only in sql/core and then we have to duplicate config options that impact optimizer/analyzer in sql/catalyst using CatalystConf. ## How was this patch tested? N/A Author: Reynold Xin Closes #17285 from rxin/SPARK-19944. --- .../spark/sql/catalyst/CatalystConf.scala | 93 --------------- .../sql/catalyst/SimpleCatalystConf.scala | 48 ++++++++ .../apache/spark/sql/catalyst/package.scala | 7 ++ .../apache/spark/sql/internal/SQLConf.scala | 106 +++++------------- .../spark/sql/internal/StaticSQLConf.scala | 84 ++++++++++++++ 5 files changed, 165 insertions(+), 173 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala rename sql/{core => catalyst}/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala (91%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala deleted file mode 100644 index cff0efa97993..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import java.util.TimeZone - -import org.apache.spark.sql.catalyst.analysis._ - -/** - * Interface for configuration options used in the catalyst module. - */ -trait CatalystConf { - def caseSensitiveAnalysis: Boolean - - def orderByOrdinal: Boolean - def groupByOrdinal: Boolean - - def optimizerMaxIterations: Int - def optimizerInSetConversionThreshold: Int - def maxCaseBranchesForCodegen: Int - - def tableRelationCacheSize: Int - - def runSQLonFile: Boolean - - def warehousePath: String - - def sessionLocalTimeZone: String - - /** If true, cartesian products between relations will be allowed for all - * join types(inner, (left|right|full) outer). - * If false, cartesian products will require explicit CROSS JOIN syntax. - */ - def crossJoinEnabled: Boolean - - /** - * Returns the [[Resolver]] for the current configuration, which can be used to determine if two - * identifiers are equal. - */ - def resolver: Resolver = { - if (caseSensitiveAnalysis) caseSensitiveResolution else caseInsensitiveResolution - } - - /** - * Enables CBO for estimation of plan statistics when set true. 
- */ - def cboEnabled: Boolean - - /** Enables join reorder in CBO. */ - def joinReorderEnabled: Boolean - - /** The maximum number of joined nodes allowed in the dynamic programming algorithm. */ - def joinReorderDPThreshold: Int - - override def clone(): CatalystConf = throw new CloneNotSupportedException() -} - - -/** A CatalystConf that can be used for local testing. */ -case class SimpleCatalystConf( - caseSensitiveAnalysis: Boolean, - orderByOrdinal: Boolean = true, - groupByOrdinal: Boolean = true, - optimizerMaxIterations: Int = 100, - optimizerInSetConversionThreshold: Int = 10, - maxCaseBranchesForCodegen: Int = 20, - tableRelationCacheSize: Int = 1000, - runSQLonFile: Boolean = true, - crossJoinEnabled: Boolean = false, - cboEnabled: Boolean = false, - joinReorderEnabled: Boolean = false, - joinReorderDPThreshold: Int = 12, - warehousePath: String = "/user/hive/warehouse", - sessionLocalTimeZone: String = TimeZone.getDefault().getID) - extends CatalystConf { - - override def clone(): SimpleCatalystConf = this.copy() -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala new file mode 100644 index 000000000000..746f84459de2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import java.util.TimeZone + +import org.apache.spark.sql.internal.SQLConf + + +/** + * A SQLConf that can be used for local testing. This class is only here to minimize the change + * for ticket SPARK-19944 (moves SQLConf from sql/core to sql/catalyst). This class should + * eventually be removed (test cases should just create SQLConf and set values appropriately). 
+ */ +case class SimpleCatalystConf( + override val caseSensitiveAnalysis: Boolean, + override val orderByOrdinal: Boolean = true, + override val groupByOrdinal: Boolean = true, + override val optimizerMaxIterations: Int = 100, + override val optimizerInSetConversionThreshold: Int = 10, + override val maxCaseBranchesForCodegen: Int = 20, + override val tableRelationCacheSize: Int = 1000, + override val runSQLonFile: Boolean = true, + override val crossJoinEnabled: Boolean = false, + override val cboEnabled: Boolean = false, + override val joinReorderEnabled: Boolean = false, + override val joinReorderDPThreshold: Int = 12, + override val warehousePath: String = "/user/hive/warehouse", + override val sessionLocalTimeZone: String = TimeZone.getDefault().getID) + extends SQLConf { + + override def clone(): SimpleCatalystConf = this.copy() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index 105cdf52500c..4af56afebb76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.internal.SQLConf + /** * Catalyst is a library for manipulating relational query plans. All classes in catalyst are * considered an internal API to Spark SQL and are subject to change between minor releases. @@ -29,4 +31,9 @@ package object catalyst { */ protected[sql] object ScalaReflectionLock + /** + * This class is only here to minimize the change for ticket SPARK-19944 + * (moves SQLConf from sql/core to sql/catalyst). This class should eventually be removed. + */ + type CatalystConf = SQLConf } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8e3f567b7dd9..315bedb12e71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -24,15 +24,11 @@ import scala.collection.JavaConverters._ import scala.collection.immutable import org.apache.hadoop.fs.Path -import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol -import org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol -import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.analysis.Resolver //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. 
@@ -251,7 +247,7 @@ object SQLConf { "of org.apache.parquet.hadoop.ParquetOutputCommitter.") .internal() .stringConf - .createWithDefault(classOf[ParquetOutputCommitter].getName) + .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") val PARQUET_VECTORIZED_READER_ENABLED = buildConf("spark.sql.parquet.enableVectorizedReader") @@ -417,7 +413,8 @@ object SQLConf { buildConf("spark.sql.sources.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[SQLHadoopMapReduceCommitProtocol].getName) + .createWithDefault( + "org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol") val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = buildConf("spark.sql.sources.parallelPartitionDiscovery.threshold") @@ -578,7 +575,7 @@ object SQLConf { buildConf("spark.sql.streaming.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[ManifestFileCommitProtocol].getName) + .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") @@ -723,7 +720,7 @@ object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf with Logging { +class SQLConf extends Serializable with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -833,6 +830,18 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) + /** + * Returns the [[Resolver]] for the current configuration, which can be used to determine if two + * identifiers are equal. 
+ */ + def resolver: Resolver = { + if (caseSensitiveAnalysis) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + } + def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) @@ -890,7 +899,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def dataFramePivotMaxValues: Int = getConf(DATAFRAME_PIVOT_MAX_VALUES) - override def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) + def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP) @@ -907,21 +916,21 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def hiveThriftServerSingleSession: Boolean = getConf(StaticSQLConf.HIVE_THRIFT_SERVER_SINGLESESSION) - override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) + def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) - override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) - override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) - override def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) + def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) def ndvMaxError: Double = getConf(NDV_MAX_ERROR) - override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) - override def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) + def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) - override def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) /** ********************** SQLConf functionality methods ************ */ @@ -1050,66 +1059,3 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { result } } - -/** - * Static SQL configuration is a cross-session, immutable Spark configuration. External users can - * see the static sql configs via `SparkSession.conf`, but can NOT set/unset them. - */ -object StaticSQLConf { - - import SQLConf.buildStaticConf - - val WAREHOUSE_PATH = buildStaticConf("spark.sql.warehouse.dir") - .doc("The default location for managed databases and tables.") - .stringConf - .createWithDefault(Utils.resolveURI("spark-warehouse").toString) - - val CATALOG_IMPLEMENTATION = buildStaticConf("spark.sql.catalogImplementation") - .internal() - .stringConf - .checkValues(Set("hive", "in-memory")) - .createWithDefault("in-memory") - - val GLOBAL_TEMP_DATABASE = buildStaticConf("spark.sql.globalTempDatabase") - .internal() - .stringConf - .createWithDefault("global_temp") - - // This is used to control when we will split a schema's JSON string to multiple pieces - // in order to fit the JSON string in metastore's table property (by default, the value has - // a length restriction of 4000 characters, so do not use a value larger than 4000 as the default - // value of this property). We will split the JSON string of a schema to its length exceeds the - // threshold. Note that, this conf is only read in HiveExternalCatalog which is cross-session, - // that's why this conf has to be a static SQL conf. 
- val SCHEMA_STRING_LENGTH_THRESHOLD = - buildStaticConf("spark.sql.sources.schemaStringLengthThreshold") - .doc("The maximum length allowed in a single cell when " + - "storing additional schema information in Hive's metastore.") - .internal() - .intConf - .createWithDefault(4000) - - val FILESOURCE_TABLE_RELATION_CACHE_SIZE = - buildStaticConf("spark.sql.filesourceTableRelationCacheSize") - .internal() - .doc("The maximum size of the cache that maps qualified table names to table relation plans.") - .intConf - .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") - .createWithDefault(1000) - - // When enabling the debug, Spark SQL internal table properties are not filtered out; however, - // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. - val DEBUG_MODE = buildStaticConf("spark.sql.debug") - .internal() - .doc("Only used for internal debugging. Not all functions are supported when it is enabled.") - .booleanConf - .createWithDefault(false) - - val HIVE_THRIFT_SERVER_SINGLESESSION = - buildStaticConf("spark.sql.hive.thriftServer.singleSession") - .doc("When set to true, Hive Thrift server is running in a single session mode. " + - "All the JDBC/ODBC connections share the temporary views, function registries, " + - "SQL configuration and the current database.") - .booleanConf - .createWithDefault(false) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala new file mode 100644 index 000000000000..af1a9cee2962 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.util.Utils + + +/** + * Static SQL configuration is a cross-session, immutable Spark configuration. External users can + * see the static sql configs via `SparkSession.conf`, but can NOT set/unset them. 
+ */ +object StaticSQLConf { + + import SQLConf.buildStaticConf + + val WAREHOUSE_PATH = buildStaticConf("spark.sql.warehouse.dir") + .doc("The default location for managed databases and tables.") + .stringConf + .createWithDefault(Utils.resolveURI("spark-warehouse").toString) + + val CATALOG_IMPLEMENTATION = buildStaticConf("spark.sql.catalogImplementation") + .internal() + .stringConf + .checkValues(Set("hive", "in-memory")) + .createWithDefault("in-memory") + + val GLOBAL_TEMP_DATABASE = buildStaticConf("spark.sql.globalTempDatabase") + .internal() + .stringConf + .createWithDefault("global_temp") + + // This is used to control when we will split a schema's JSON string to multiple pieces + // in order to fit the JSON string in metastore's table property (by default, the value has + // a length restriction of 4000 characters, so do not use a value larger than 4000 as the default + // value of this property). We will split the JSON string of a schema to its length exceeds the + // threshold. Note that, this conf is only read in HiveExternalCatalog which is cross-session, + // that's why this conf has to be a static SQL conf. + val SCHEMA_STRING_LENGTH_THRESHOLD = + buildStaticConf("spark.sql.sources.schemaStringLengthThreshold") + .doc("The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.") + .internal() + .intConf + .createWithDefault(4000) + + val FILESOURCE_TABLE_RELATION_CACHE_SIZE = + buildStaticConf("spark.sql.filesourceTableRelationCacheSize") + .internal() + .doc("The maximum size of the cache that maps qualified table names to table relation plans.") + .intConf + .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") + .createWithDefault(1000) + + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, + // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. + val DEBUG_MODE = buildStaticConf("spark.sql.debug") + .internal() + .doc("Only used for internal debugging. Not all functions are supported when it is enabled.") + .booleanConf + .createWithDefault(false) + + val HIVE_THRIFT_SERVER_SINGLESESSION = + buildStaticConf("spark.sql.hive.thriftServer.singleSession") + .doc("When set to true, Hive Thrift server is running in a single session mode. " + + "All the JDBC/ODBC connections share the temporary views, function registries, " + + "SQL configuration and the current database.") + .booleanConf + .createWithDefault(false) +} From a0b92f73fed9b91883f08cced1c09724e09e1883 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 14 Mar 2017 12:49:30 +0100 Subject: [PATCH 023/512] [SPARK-19850][SQL] Allow the use of aliases in SQL function calls ## What changes were proposed in this pull request? We currently cannot use aliases in SQL function calls. This is inconvenient when you try to create a struct. This SQL query for example `select struct(1, 2) st`, will create a struct with column names `col1` and `col2`. This is even more problematic when we want to append a field to an existing struct. For example if we want to a field to struct `st` we would issue the following SQL query `select struct(st.*, 1) as st from src`, the result will be struct `st` with an a column with a non descriptive name `col3` (if `st` itself has 2 fields). This PR proposes to change this by allowing the use of aliased expression in function parameters. 
For example `select struct(1 as a, 2 as b) st`, will create a struct with columns `a` & `b`. ## How was this patch tested? Added a test to `ExpressionParserSuite` and added a test file for `SQLQueryTestSuite`. Author: Herman van Hovell Closes #17245 from hvanhovell/SPARK-19850. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 7 ++- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../parser/ExpressionParserSuite.scala | 2 + .../resources/sql-tests/inputs/struct.sql | 20 +++++++ .../sql-tests/results/struct.sql.out | 60 +++++++++++++++++++ 5 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/struct.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/struct.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 59f93b3c469d..cc3b8fd3b468 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -506,10 +506,10 @@ expression booleanExpression : NOT booleanExpression #logicalNot + | EXISTS '(' query ')' #exists | predicated #booleanDefault | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary - | EXISTS '(' query ')' #exists ; // workaround for: @@ -546,9 +546,10 @@ primaryExpression | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star - | '(' expression (',' expression)+ ')' #rowConstructor + | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor | '(' query ')' #subqueryExpression - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | qualifiedName '(' (setQuantifier? namedExpression (',' namedExpression)*)? ')' + (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3cf11adc1953..4c9fb2ec2774 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1016,7 +1016,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.expression().asScala.map(expression) match { + val arguments = ctx.namedExpression().asScala.map(expression) match { case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). Seq(Literal(1)) @@ -1127,7 +1127,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[CreateStruct]] expression. 
*/ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.expression.asScala.map(expression)) + CreateStruct(ctx.namedExpression().asScala.map(expression)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2fecb8dc4a60..c2e62e739776 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -209,6 +209,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + assertEqual("foo(a as x, b as e)", 'foo.function('a as 'x, 'b as 'e)) } test("window function expressions") { @@ -278,6 +279,7 @@ class ExpressionParserSuite extends PlanTest { // Note that '(a)' will be interpreted as a nested expression. assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + assertEqual("(a as b, b as c)", CreateStruct(Seq('a as 'b, 'b as 'c))) } test("scalar sub-query") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/struct.sql b/sql/core/src/test/resources/sql-tests/inputs/struct.sql new file mode 100644 index 000000000000..e56344dc4de8 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/struct.sql @@ -0,0 +1,20 @@ +CREATE TEMPORARY VIEW tbl_x AS VALUES + (1, NAMED_STRUCT('C', 'gamma', 'D', 'delta')), + (2, NAMED_STRUCT('C', 'epsilon', 'D', 'eta')), + (3, NAMED_STRUCT('C', 'theta', 'D', 'iota')) + AS T(ID, ST); + +-- Create a struct +SELECT STRUCT('alpha', 'beta') ST; + +-- Create a struct with aliases +SELECT STRUCT('alpha' AS A, 'beta' AS B) ST; + +-- Star expansion in a struct. 
+SELECT ID, STRUCT(ST.*) NST FROM tbl_x; + +-- Append a column to a struct +SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x; + +-- Prepend a column to a struct +SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x; diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out new file mode 100644 index 000000000000..3e32f4619546 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -0,0 +1,60 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE TEMPORARY VIEW tbl_x AS VALUES + (1, NAMED_STRUCT('C', 'gamma', 'D', 'delta')), + (2, NAMED_STRUCT('C', 'epsilon', 'D', 'eta')), + (3, NAMED_STRUCT('C', 'theta', 'D', 'iota')) + AS T(ID, ST) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT STRUCT('alpha', 'beta') ST +-- !query 1 schema +struct> +-- !query 1 output +{"col1":"alpha","col2":"beta"} + + +-- !query 2 +SELECT STRUCT('alpha' AS A, 'beta' AS B) ST +-- !query 2 schema +struct> +-- !query 2 output +{"A":"alpha","B":"beta"} + + +-- !query 3 +SELECT ID, STRUCT(ST.*) NST FROM tbl_x +-- !query 3 schema +struct> +-- !query 3 output +1 {"C":"gamma","D":"delta"} +2 {"C":"epsilon","D":"eta"} +3 {"C":"theta","D":"iota"} + + +-- !query 4 +SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x +-- !query 4 schema +struct> +-- !query 4 output +1 {"C":"gamma","D":"delta","E":"1"} +2 {"C":"epsilon","D":"eta","E":"2"} +3 {"C":"theta","D":"iota","E":"3"} + + +-- !query 5 +SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x +-- !query 5 schema +struct> +-- !query 5 output +1 {"AA":"1","C":"gamma","D":"delta"} +2 {"AA":"2","C":"epsilon","D":"eta"} +3 {"AA":"3","C":"theta","D":"iota"} From 1c7275efa7bfaaa92719750e93a7b35cbcb48e45 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 14 Mar 2017 14:02:48 +0100 Subject: [PATCH 024/512] [SPARK-18874][SQL] Fix 2.10 build after moving the subquery rules to optimization ## What changes were proposed in this pull request? Commit https://github.com/apache/spark/commit/4ce970d71488c7de6025ef925f75b8b92a5a6a79 in accidentally broke the 2.10 build for Spark. This PR fixes this by simplifying the offending pattern match. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17288 from hvanhovell/SPARK-18874. 
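As a self-contained illustration of the simplification (a toy ADT, not Spark's expression classes), the sketch below only shows that splitting a long pattern alternation across two cases preserves the match semantics, which is all the fix relies on:

```scala
// Toy ADT standing in for the expression classes; only the shape of the match matters here.
object PatternSplitSketch {
  sealed trait Expr
  case class Exists(id: Int) extends Expr
  case class InList(id: Int) extends Expr
  case class Not(child: Expr) extends Expr

  // Before: one case with a four-way pattern alternation.
  def before(e: Expr): Boolean = e match {
    case _: Exists | Not(_: Exists) | _: InList | Not(_: InList) => false
    case _ => true
  }

  // After: the same alternation split across two cases; the semantics are unchanged.
  def after(e: Expr): Boolean = e match {
    case _: Exists | Not(_: Exists) => false
    case _: InList | Not(_: InList) => false
    case _ => true
  }

  def main(args: Array[String]): Unit = {
    val samples = Seq(Exists(1), Not(Exists(1)), InList(2), Not(InList(2)), Not(Not(Exists(3))))
    assert(samples.forall(e => before(e) == after(e)))
    println("both formulations agree on all samples")
  }
}
```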
--- .../org/apache/spark/sql/catalyst/expressions/subquery.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index ad11700fa28d..59db28d58afc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -100,8 +100,8 @@ object SubExprUtils extends PredicateHelper { */ def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { splitConjunctivePredicates(condition).exists { - case _: Exists | Not(_: Exists) | In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => - false + case _: Exists | Not(_: Exists) => false + case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false case e => e.find { x => x.isInstanceOf[Not] && e.find { case In(_, Seq(_: ListQuery)) => true From 5e96a57b2f383d4b33735681b41cd3ec06570671 Mon Sep 17 00:00:00 2001 From: Asher Krim Date: Tue, 14 Mar 2017 13:08:11 +0000 Subject: [PATCH 025/512] [SPARK-19922][ML] small speedups to findSynonyms Currently generating synonyms using a large model (I've tested with 3m words) is very slow. These efficiencies have sped things up for us by ~17% I wasn't sure if such small changes were worthy of a jira, but the guidelines seemed to suggest that that is the preferred approach ## What changes were proposed in this pull request? Address a few small issues in the findSynonyms logic: 1) remove usage of ``Array.fill`` to zero out the ``cosineVec`` array. The default float value in Scala and Java is 0.0f, so explicitly setting the values to zero is not needed 2) use Floats throughout. The conversion to Doubles before doing the ``priorityQueue`` is totally superfluous, since all the similarity computations are done using Floats anyway. Creating a second large array just serves to put extra strain on the GC 3) convert the slow ``for(i <- cosVec.indices)`` to an ugly, but faster, ``while`` loop These efficiencies are really only apparent when working with a large model ## How was this patch tested? Existing unit tests + some in-house tests to time the difference cc jkbradley MLNick srowen Author: Asher Krim Author: Asher Krim Closes #17263 from Krimit/fasterFindSynonyms. --- .../apache/spark/mllib/feature/Word2Vec.scala | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 531c8b07910f..6f96813497b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -491,8 +491,8 @@ class Word2VecModel private[spark] ( // wordVecNorms: Array of length numWords, each value being the Euclidean norm // of the wordVector. 
- private val wordVecNorms: Array[Double] = { - val wordVecNorms = new Array[Double](numWords) + private val wordVecNorms: Array[Float] = { + val wordVecNorms = new Array[Float](numWords) var i = 0 while (i < numWords) { val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) @@ -570,7 +570,7 @@ class Word2VecModel private[spark] ( require(num > 0, "Number of similar words should > 0") val fVector = vector.toArray.map(_.toFloat) - val cosineVec = Array.fill[Float](numWords)(0) + val cosineVec = new Array[Float](numWords) val alpha: Float = 1 val beta: Float = 0 // Normalize input vector before blas.sgemv to avoid Inf value @@ -581,22 +581,23 @@ class Word2VecModel private[spark] ( blas.sgemv( "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) - val cosVec = cosineVec.map(_.toDouble) - var ind = 0 - while (ind < numWords) { - val norm = wordVecNorms(ind) - if (norm == 0.0) { - cosVec(ind) = 0.0 + var i = 0 + while (i < numWords) { + val norm = wordVecNorms(i) + if (norm == 0.0f) { + cosineVec(i) = 0.0f } else { - cosVec(ind) /= norm + cosineVec(i) /= norm } - ind += 1 + i += 1 } - val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2)) + val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2)) - for(i <- cosVec.indices) { - pq += Tuple2(wordList(i), cosVec(i)) + var j = 0 + while (j < numWords) { + pq += Tuple2(wordList(j), cosineVec(j)) + j += 1 } val scored = pq.toSeq.sortBy(-_._2) @@ -606,7 +607,10 @@ class Word2VecModel private[spark] ( case None => scored } - filtered.take(num).toArray + filtered + .take(num) + .map { case (word, score) => (word, score.toDouble) } + .toArray } /** From d4a637cd46b6dd5cc71ea17a55c4a26186e592c7 Mon Sep 17 00:00:00 2001 From: zero323 Date: Tue, 14 Mar 2017 07:34:44 -0700 Subject: [PATCH 026/512] [SPARK-19940][ML][MINOR] FPGrowthModel.transform should skip duplicated items ## What changes were proposed in this pull request? This commit moved `distinct` in its intended place to avoid duplicated predictions and adds unit test covering the issue. ## How was this patch tested? Unit tests. Author: zero323 Closes #17283 from zero323/SPARK-19940. 
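A collections-only sketch of the behaviour being fixed (toy rules and items, no Spark dependency; all names below are made up for illustration). In the old shape the trailing `.distinct` effectively applies to the `else` branch (`Seq.empty`), so consequents contributed by different rules were never deduplicated; applying `.distinct` to the flattened result removes the duplicates:

```scala
object DistinctPlacementSketch {
  def main(args: Array[String]): Unit = {
    // Two toy rules, "1" -> "3" and "2" -> "3", applied to a basket containing both antecedents.
    val rules: Seq[(Seq[String], Seq[String])] = Seq(Seq("1") -> Seq("3"), Seq("2") -> Seq("3"))
    val items = Seq("1", "2")

    def consequents(): Seq[String] = rules.flatMap { case (antecedent, consequent) =>
      if (antecedent.forall(a => items.contains(a))) consequent.filterNot(c => items.contains(c))
      else Seq.empty[String]
    }

    // Old shape: `.distinct` parses as part of the else branch, so the then branch keeps duplicates.
    val withDistinctOnElse = if (items.nonEmpty) { consequents() } else { Seq.empty[String] }.distinct
    // New shape: `.distinct` is applied to the flattened predictions.
    val withDistinctOnResult = if (items.nonEmpty) { consequents().distinct } else { Seq.empty[String] }

    println(withDistinctOnElse)   // List(3, 3)
    println(withDistinctOnResult) // List(3)
  }
}
```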
--- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 4 ++-- .../org/apache/spark/ml/fpm/FPGrowthSuite.scala | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 417968d9b817..fa39dd954af5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -245,10 +245,10 @@ class FPGrowthModel private[ml] ( rule._2.filter(item => !itemset.contains(item)) } else { Seq.empty - }) + }).distinct } else { Seq.empty - }.distinct }, dt) + }}, dt) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 076d55c18054..910d4b07d130 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul FPGrowthSuite.allParamSettings, checkModelData) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("features") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } } object FPGrowthSuite { From 85941ecf28362f35718ebcd3a22dbb17adb49154 Mon Sep 17 00:00:00 2001 From: Menglong TAN Date: Tue, 14 Mar 2017 07:45:42 -0700 Subject: [PATCH 027/512] [SPARK-11569][ML] Fix StringIndexer to handle null value properly ## What changes were proposed in this pull request? This PR is to enhance StringIndexer with NULL values handling. Before the PR, StringIndexer will throw an exception when encounters NULL values. With this PR: - handleInvalid=error: Throw an exception as before - handleInvalid=skip: Skip null values as well as unseen labels - handleInvalid=keep: Give null values an additional index as well as unseen labels BTW, I noticed someone was trying to solve the same problem ( #9920 ) but seems getting no progress or response for a long time. Would you mind to give me a chance to solve it ? I'm eager to help. :-) ## How was this patch tested? new unit tests Author: Menglong TAN Author: Menglong TAN Closes #17233 from crackcell/11569_StringIndexer_NULL. --- .../spark/ml/feature/StringIndexer.scala | 54 +++++++++++-------- .../spark/ml/feature/StringIndexerSuite.scala | 45 ++++++++++++++++ 2 files changed, 77 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 810b02febbe7..99321bcc7cf9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { /** - * Param for how to handle unseen labels. 
Options are 'skip' (filter out rows with - * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional - * bucket, at index numLabels. + * Param for how to handle invalid data (unseen labels or NULL values). + * Options are 'skip' (filter out rows with invalid data), + * 'error' (throw an error), or 'keep' (put invalid data in a special additional + * bucket, at index numLabels). * Default: "error" * @group param */ @Since("1.6.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + - "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + - "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + - "at index numLabels).", + "invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) - setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) + setDefault(handleInvalid, StringIndexer.ERROR_INVALID) /** @group getParam */ @Since("1.6.0") @@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val counts = dataset.select(col($(inputCol)).cast(StringType)) + val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) .countByValue() @@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { - private[feature] val SKIP_UNSEEN_LABEL: String = "skip" - private[feature] val ERROR_UNSEEN_LABEL: String = "error" - private[feature] val KEEP_UNSEEN_LABEL: String = "keep" + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = - Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -188,7 +189,7 @@ class StringIndexerModel ( transformSchema(dataset.schema, logging = true) val filteredLabels = getHandleInvalid match { - case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" + case StringIndexer.KEEP_INVALID => labels :+ "__unknown" case _ => labels } @@ -196,22 +197,31 @@ class StringIndexerModel ( .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. 
val (filteredDataset, keepInvalid) = getHandleInvalid match { - case StringIndexer.SKIP_UNSEEN_LABEL => + case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) } - (dataset.where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) + (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) } val indexer = udf { label: String => - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else if (keepInvalid) { - labels.length + if (label == null) { + if (keepInvalid) { + labels.length + } else { + throw new SparkException("StringIndexer encountered NULL value. To handle or skip " + + "NULLS, try setting StringIndexer.handleInvalid.") + } } else { - throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + - s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 188dffb3dd55..8d9042b31e03 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -122,6 +122,51 @@ class StringIndexerSuite assert(output === expected) } + test("StringIndexer with NULLs") { + val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null)) + val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null)) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") + + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + withClue("StringIndexer should throw error when setHandleInvalid=error " + + "when given NULL values") { + intercept[SparkException] { + indexer.setHandleInvalid("error") + indexer.fit(df).transform(df2).collect() + } + } + + indexer.setHandleInvalid("skip") + val transformedSkip = indexer.fit(df).transform(df2) + val attrSkip = Attribute + .fromStructField(transformedSkip.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + val transformedKeep = indexer.fit(df).transform(df2) + val attrKeep = Attribute + .fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, null -> 2 + val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) + assert(outputKeep === expectedKeep) + } + test("StringIndexerModel should keep silent if the input column does not exist.") { val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") From a02a0b1703dafab541c9b57939e3ed37e412d0f8 Mon Sep 
17 00:00:00 2001 From: jiangxingbo Date: Tue, 14 Mar 2017 10:13:50 -0700 Subject: [PATCH 028/512] [SPARK-18961][SQL] Support `SHOW TABLE EXTENDED ... PARTITION` statement ## What changes were proposed in this pull request? We should support the statement `SHOW TABLE EXTENDED LIKE 'table_identifier' PARTITION(partition_spec)`, just like that HIVE does. When partition is specified, the `SHOW TABLE EXTENDED` command should output the information of the partitions instead of the tables. Note that in this statement, we require exact matched partition spec. For example: ``` CREATE TABLE show_t1(a String, b Int) PARTITIONED BY (c String, d String); ALTER TABLE show_t1 ADD PARTITION (c='Us', d=1) PARTITION (c='Us', d=22); -- Output the extended information of Partition(c='Us', d=1) SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1); -- Throw an AnalysisException SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us'); ``` ## How was this patch tested? Add new test sqls in file `show-tables.sql`. Add new test case in `DDLSuite`. Author: jiangxingbo Closes #16373 from jiangxb1987/show-partition-extended. --- .../spark/sql/execution/QueryExecution.scala | 4 +- .../spark/sql/execution/SparkSqlParser.scala | 11 +- .../spark/sql/execution/command/tables.scala | 44 ++++-- .../sql-tests/inputs/show-tables.sql | 15 +- .../sql-tests/results/show-tables.sql.out | 133 ++++++++++++++---- .../apache/spark/sql/SQLQueryTestSuite.scala | 5 +- .../sql/execution/command/DDLSuite.scala | 34 ----- 7 files changed, 163 insertions(+), 83 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9a3656ddc79f..8e8210e334a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -127,8 +127,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .map(s => String.format(s"%-20s", s)) .mkString("\t") } - // SHOW TABLES in Hive only output table names, while ours outputs database, table name, isTemp. - case command: ExecutedCommandExec if command.cmd.isInstanceOf[ShowTablesCommand] => + // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 00d1d6d2701f..abea7a3bcf14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -134,7 +134,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ShowTablesCommand( Option(ctx.db).map(_.getText), Option(ctx.pattern).map(string), - isExtended = false) + isExtended = false, + partitionSpec = None) } /** @@ -146,14 +147,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitShowTable(ctx: ShowTableContext): LogicalPlan = withOrigin(ctx) { - if (ctx.partitionSpec != null) { - operationNotAllowed("SHOW TABLE EXTENDED ... 
PARTITION", ctx) - } - + val partitionSpec = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec) ShowTablesCommand( Option(ctx.db).map(_.getText), Option(ctx.pattern).map(string), - isExtended = true) + isExtended = true, + partitionSpec = partitionSpec) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 86394ff23e37..beb3dcafd64f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -616,13 +616,15 @@ case class DescribeTableCommand( * The syntax of using this command in SQL is: * {{{ * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; - * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards'; + * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards' + * [PARTITION(partition_spec)]; * }}} */ case class ShowTablesCommand( databaseName: Option[String], tableIdentifierPattern: Option[String], - isExtended: Boolean = false) extends RunnableCommand { + isExtended: Boolean = false, + partitionSpec: Option[TablePartitionSpec] = None) extends RunnableCommand { // The result of SHOW TABLES/SHOW TABLE has three basic columns: database, tableName and // isTemporary. If `isExtended` is true, append column `information` to the output columns. @@ -642,18 +644,34 @@ case class ShowTablesCommand( // instead of calling tables in sparkSession. val catalog = sparkSession.sessionState.catalog val db = databaseName.getOrElse(catalog.getCurrentDatabase) - val tables = - tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) - tables.map { tableIdent => - val database = tableIdent.database.getOrElse("") - val tableName = tableIdent.table - val isTemp = catalog.isTemporaryTable(tableIdent) - if (isExtended) { - val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).toString - Row(database, tableName, isTemp, s"${information}\n") - } else { - Row(database, tableName, isTemp) + if (partitionSpec.isEmpty) { + // Show the information of tables. + val tables = + tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) + tables.map { tableIdent => + val database = tableIdent.database.getOrElse("") + val tableName = tableIdent.table + val isTemp = catalog.isTemporaryTable(tableIdent) + if (isExtended) { + val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).toString + Row(database, tableName, isTemp, s"$information\n") + } else { + Row(database, tableName, isTemp) + } } + } else { + // Show the information of partitions. + // + // Note: tableIdentifierPattern should be non-empty, otherwise a [[ParseException]] + // should have been thrown by the sql parser. 
+ val tableIdent = TableIdentifier(tableIdentifierPattern.get, Some(db)) + val table = catalog.getTableMetadata(tableIdent).identifier + val partition = catalog.getPartition(tableIdent, partitionSpec.get) + val database = table.database.getOrElse("") + val tableName = table.table + val isTemp = catalog.isTemporaryTable(table) + val information = partition.toString + Seq(Row(database, tableName, isTemp, s"$information\n")) } } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql b/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql index 10c379dfa014..3c77c9977d80 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql @@ -17,10 +17,21 @@ SHOW TABLES LIKE 'show_t1*|show_t2*'; SHOW TABLES IN showdb 'show_t*'; -- SHOW TABLE EXTENDED --- Ignore these because there exist timestamp results, e.g. `Created`. --- SHOW TABLE EXTENDED LIKE 'show_t*'; +SHOW TABLE EXTENDED LIKE 'show_t*'; SHOW TABLE EXTENDED; + +-- SHOW TABLE EXTENDED ... PARTITION +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1); +-- Throw a ParseException if table name is not specified. +SHOW TABLE EXTENDED PARTITION(c='Us', d=1); +-- Don't support regular expression for table name if a partition specification is present. +SHOW TABLE EXTENDED LIKE 'show_t*' PARTITION(c='Us', d=1); +-- Partition specification is not complete. SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us'); +-- Partition specification is invalid. +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(a='Us', d=1); +-- Partition specification doesn't exist. +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Ch', d=1); -- Clean Up DROP TABLE show_t1; diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 3d287f43accc..6d62e6092147 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 26 -- !query 0 @@ -114,76 +114,159 @@ show_t3 -- !query 12 -SHOW TABLE EXTENDED +SHOW TABLE EXTENDED LIKE 'show_t*' -- !query 12 schema -struct<> +struct -- !query 12 output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '' expecting 'LIKE'(line 1, pos 19) - -== SQL == -SHOW TABLE EXTENDED --------------------^^^ +show_t3 true CatalogTable( + Table: `show_t3` + Created: + Last Access: + Type: VIEW + Schema: [StructField(e,IntegerType,true)] + Storage()) + +showdb show_t1 false CatalogTable( + Table: `showdb`.`show_t1` + Created: + Last Access: + Type: MANAGED + Schema: [StructField(a,StringType,true), StructField(b,IntegerType,true), StructField(c,StringType,true), StructField(d,StringType,true)] + Provider: parquet + Partition Columns: [`c`, `d`] + Storage(Location: sql/core/spark-warehouse/showdb.db/show_t1) + Partition Provider: Catalog) + +showdb show_t2 false CatalogTable( + Table: `showdb`.`show_t2` + Created: + Last Access: + Type: MANAGED + Schema: [StructField(b,StringType,true), StructField(d,IntegerType,true)] + Provider: parquet + Storage(Location: sql/core/spark-warehouse/showdb.db/show_t2)) -- !query 13 -SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') +SHOW TABLE EXTENDED -- !query 13 schema struct<> -- !query 13 output org.apache.spark.sql.catalyst.parser.ParseException -Operation not allowed: SHOW TABLE EXTENDED ... 
PARTITION(line 1, pos 0) +mismatched input '' expecting 'LIKE'(line 1, pos 19) == SQL == -SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') -^^^ +SHOW TABLE EXTENDED +-------------------^^^ -- !query 14 -DROP TABLE show_t1 +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1) -- !query 14 schema -struct<> +struct -- !query 14 output - +showdb show_t1 false CatalogPartition( + Partition Values: [c=Us, d=1] + Storage(Location: sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1) + Partition Parameters:{}) -- !query 15 -DROP TABLE show_t2 +SHOW TABLE EXTENDED PARTITION(c='Us', d=1) -- !query 15 schema struct<> -- !query 15 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'PARTITION' expecting 'LIKE'(line 1, pos 20) +== SQL == +SHOW TABLE EXTENDED PARTITION(c='Us', d=1) +--------------------^^^ -- !query 16 -DROP VIEW show_t3 +SHOW TABLE EXTENDED LIKE 'show_t*' PARTITION(c='Us', d=1) -- !query 16 schema struct<> -- !query 16 output - +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +Table or view 'show_t*' not found in database 'showdb'; -- !query 17 -DROP VIEW global_temp.show_t4 +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') -- !query 17 schema struct<> -- !query 17 output - +org.apache.spark.sql.AnalysisException +Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`showdb`.`show_t1`'; -- !query 18 -USE default +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(a='Us', d=1) -- !query 18 schema struct<> -- !query 18 output - +org.apache.spark.sql.AnalysisException +Partition spec is invalid. The spec (a, d) must match the partition spec (c, d) defined in table '`showdb`.`show_t1`'; -- !query 19 -DROP DATABASE showdb +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Ch', d=1) -- !query 19 schema struct<> -- !query 19 output +org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException +Partition not found in table 'show_t1' database 'showdb': +c -> Ch +d -> 1; + + +-- !query 20 +DROP TABLE show_t1 +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +DROP TABLE show_t2 +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +DROP VIEW show_t3 +-- !query 22 schema +struct<> +-- !query 22 output + + + +-- !query 23 +DROP VIEW global_temp.show_t4 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +USE default +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +DROP DATABASE showdb +-- !query 25 schema +struct<> +-- !query 25 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 68ababcd1102..c285995514c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -222,7 +222,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val df = session.sql(sql) val schema = df.schema // Get answer, but also get rid of the #1234 expression ids that show up in explain plans - val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x")) + val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") + .replaceAll("Location: .*/sql/core/", "Location: sql/core/") + .replaceAll("Created: .*\n", "Created: \n") + .replaceAll("Last Access: .*\n", "Last Access: \n")) // If the output is not pre-sorted, sort it. 
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 0666f446f3b5..6eed10ec5146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -977,40 +977,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { testRenamePartitions(isDatasourceTable = false) } - test("show table extended") { - withTempView("show1a", "show2b") { - sql( - """ - |CREATE TEMPORARY VIEW show1a - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - | - |) - """.stripMargin) - sql( - """ - |CREATE TEMPORARY VIEW show2b - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) - assert( - sql("SHOW TABLE EXTENDED LIKE 'show*'").count() >= 2) - assert( - sql("SHOW TABLE EXTENDED LIKE 'show*'").schema == - StructType(StructField("database", StringType, false) :: - StructField("tableName", StringType, false) :: - StructField("isTemporary", BooleanType, false) :: - StructField("information", StringType, false) :: Nil)) - } - } - test("show databases") { sql("CREATE DATABASE showdb2B") sql("CREATE DATABASE showdb1A") From 6325a2f82a95a63bee020122620bc4f5fd25d059 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 14 Mar 2017 18:51:05 +0100 Subject: [PATCH 029/512] [SPARK-19923][SQL] Remove unnecessary type conversions per call in Hive ## What changes were proposed in this pull request? This pr removed unnecessary type conversions per call in Hive: https://github.com/apache/spark/blob/master/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala#L116 ## How was this patch tested? Existing tests Author: Takeshi Yamamuro Closes #17264 from maropu/SPARK-19923. 
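As a rough illustration of the change (the types below are simplified stand-ins, not the actual `HiveInspectors`/`ObjectInspector` API), the idea is to resolve the Catalyst-to-Hive wrapper once per operator instead of re-dispatching on the data type for every value:

```scala
object WrapperSketch {
  type Wrapper = Any => AnyRef

  // Stand-in for wrapperFor(objectInspector, dataType): choose the conversion up front.
  def wrapperFor(dataType: String): Wrapper = dataType match {
    case "int"    => v => Int.box(v.asInstanceOf[Int])
    case "string" => v => v.toString
    case _        => v => v.asInstanceOf[AnyRef]
  }

  // Before: the conversion is re-resolved on every call in the per-row hot path.
  class DeferredValueBefore(dataType: String) {
    def get(value: Any): AnyRef = wrapperFor(dataType)(value)
  }

  // After: the wrapper is resolved once at construction time, so get() is a plain call.
  class DeferredValueAfter(dataType: String) {
    private val wrapper: Wrapper = wrapperFor(dataType)
    def get(value: Any): AnyRef = wrapper(value)
  }
}
```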
--- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 3 ++- .../apache/spark/sql/hive/orc/OrcFileFormat.scala | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 506949cb682b..51c814cf32a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -108,12 +108,13 @@ private[hive] case class HiveSimpleUDF( private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType) extends DeferredObject with HiveInspectors { + private val wrapper = wrapperFor(oi, dataType) private var func: () => Any = _ def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrap(func(), oi, dataType) + override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef] } private[hive] case class HiveGenericUDF( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index f496c01ce9ff..3a34ec55c8b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.orc import java.net.URI import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -196,6 +198,11 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) private[this] val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] + // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format + private[this] val wrappers = dataSchema.zip(structOI.getAllStructFieldRefs().asScala.toSeq).map { + case (f, i) => wrapperFor(i.getFieldObjectInspector, f.dataType) + } + private[this] def wrapOrcStruct( struct: OrcStruct, oi: SettableStructObjectInspector, @@ -208,10 +215,8 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) oi.setStructFieldData( struct, fieldRefs.get(i), - wrap( - row.get(i, dataSchema(i).dataType), - fieldRefs.get(i).getFieldObjectInspector, - dataSchema(i).dataType)) + wrappers(i)(row.get(i, dataSchema(i).dataType)) + ) i += 1 } } From e04c05cf41a125b0526f59f9b9e7fdf0b78b8b21 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 14 Mar 2017 18:52:16 +0100 Subject: [PATCH 030/512] [SPARK-19933][SQL] Do not change output of a subquery ## What changes were proposed in this pull request? The `RemoveRedundantAlias` rule can change the output attributes (the expression id's to be precise) of a query by eliminating the redundant alias producing them. This is no problem for a regular query, but can cause problems for correlated subqueries: The attributes produced by the subquery are used in the parent plan; changing them will break the parent plan. This PR fixes this by wrapping a subquery in a `Subquery` top level node when it gets optimized. The `RemoveRedundantAlias` rule now recognizes `Subquery` and makes sure that the output attributes of the `Subquery` node are retained. ## How was this patch tested? Added a test case to `RemoveRedundantAliasAndProjectSuite` and added a regression test to `SubquerySuite`. 
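For illustration, the shape of the query this change protects, taken from the regression test added below (assumes a `SparkSession` named `spark`):

```scala
// The alias `id AS id` inside the subquery is redundant, but stripping it used to
// rewrite the subquery's output attribute ids and break the parent's IN condition.
spark.range(4).createOrReplaceTempView("t1")
spark.range(2).createOrReplaceTempView("t2")

spark.sql("SELECT * FROM t1 WHERE id IN (SELECT id AS id FROM t2)").show()
// With the Subquery wrapper in place this returns rows 0 and 1, as expected.
```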
Author: Herman van Hovell Closes #17278 from hvanhovell/SPARK-19933. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 15 ++++++++++++--- .../plans/logical/basicLogicalOperators.scala | 8 ++++++++ .../RemoveRedundantAliasAndProjectSuite.scala | 8 ++++++++ .../org/apache/spark/sql/SubquerySuite.scala | 14 ++++++++++++++ 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e9dbded3d4d0..c8ed4190a13a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -142,7 +142,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) object OptimizeSubqueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s: SubqueryExpression => - s.withNewPlan(Optimizer.this.execute(s.plan)) + val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) + s.withNewPlan(newPlan) } } } @@ -187,7 +188,10 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // If the alias name is different from attribute name, we can't strip it either, or we // may accidentally change the output schema name of the root plan. case a @ Alias(attr: Attribute, name) - if a.metadata == Metadata.empty && name == attr.name && !blacklist.contains(attr) => + if a.metadata == Metadata.empty && + name == attr.name && + !blacklist.contains(attr) && + !blacklist.contains(a) => attr case a => a } @@ -195,10 +199,15 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { /** * Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to * prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self) - * join. + * join or to prevent the removal of top-level subquery attributes. */ private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = { plan match { + // We want to keep the same output attributes for subqueries. This means we cannot remove + // the aliases that produce these attributes + case Subquery(child) => + Subquery(removeRedundantAliases(child, blacklist ++ child.outputSet)) + // A join has to be treated differently, because the left and the right side of the join are // not allowed to use the same attributes. We use a blacklist to prevent us from creating a // situation in which this happens; the rule will only remove an alias if its child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 31b6ed48a223..5cbf263d1ce4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -38,6 +38,14 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * This node is inserted at the top of a subquery when it is optimized. This makes sure we can + * recognize a subquery as such, and it allows us to write subquery aware transformations. 
+ */ +case class Subquery(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index c01ea01ec680..1973b5abb462 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -116,4 +116,12 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper val expected = relation.window(Seq('b), Seq('a), Seq()).analyze comparePlans(optimized, expected) } + + test("do not remove output attributes from a subquery") { + val relation = LocalRelation('a.int, 'b.int) + val query = Subquery(relation.select('a as "a", 'b as "b").where('b < 10).select('a).analyze) + val optimized = Optimize.execute(query) + val expected = Subquery(relation.select('a as "a", 'b).where('b < 10).select('a).analyze) + comparePlans(optimized, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 6f1cd49c08ee..5fe6667ceca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -830,4 +830,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(1) :: Row(0) :: Nil) } } + + test("SPARK-19933 Do not eliminate top-level aliases in sub-queries") { + withTempView("t1", "t2") { + spark.range(4).createOrReplaceTempView("t1") + checkAnswer( + sql("select * from t1 where id in (select id as id from t1)"), + Row(0) :: Row(1) :: Row(2) :: Row(3) :: Nil) + + spark.range(2).createOrReplaceTempView("t2") + checkAnswer( + sql("select * from t1 where id in (select id as id from t2)"), + Row(0) :: Row(1) :: Nil) + } + } } From 6eac96823c7b244773bd810812b369e336a65837 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 14 Mar 2017 20:34:59 +0100 Subject: [PATCH 031/512] [SPARK-18966][SQL] NOT IN subquery with correlated expressions may return incorrect result ## What changes were proposed in this pull request? This PR fixes the following problem: ```` Seq((1, 2)).toDF("a1", "a2").createOrReplaceTempView("a") Seq[(java.lang.Integer, java.lang.Integer)]((1, null)).toDF("b1", "b2").createOrReplaceTempView("b") // The expected result is 1 row of (1,2) as shown in the next statement. sql("select * from a where a1 not in (select b1 from b where b2 = a2)").show +---+---+ | a1| a2| +---+---+ +---+---+ sql("select * from a where a1 not in (select b1 from b where b2 = 2)").show +---+---+ | a1| a2| +---+---+ | 1| 2| +---+---+ ```` There are a number of scenarios to consider: 1. When the correlated predicate yields a match (i.e., B.B2 = A.A2) 1.1. When the NOT IN expression yields a match (i.e., A.A1 = B.B1) 1.2. When the NOT IN expression yields no match (i.e., A.A1 = B.B1 returns false) 1.3. When A.A1 is null 1.4. When B.B1 is null 1.4.1. When A.A1 is not null 1.4.2. When A.A1 is null 2. 
When the correlated predicate yields no match (i.e.,B.B2 = A.A2 is false or unknown) 2.1. When B.B2 is null and A.A2 is null 2.2. When B.B2 is null and A.A2 is not null 2.3. When the value of A.A2 does not match any of B.B2 ```` A.A1 A.A2 B.B1 B.B2 ----- ----- ----- ----- 1 1 1 1 (1.1) 2 1 (1.2) null 1 (1.3) 1 3 null 3 (1.4.1) null 3 (1.4.2) 1 null 1 null (2.1) null 2 (2.2 & 2.3) ```` We can divide the evaluation of the above correlated NOT IN subquery into 2 groups:- Group 1: The rows in A when there is a match from the correlated predicate (A.A1 = B.B1) In this case, the result of the subquery is not empty and the semantics of the NOT IN depends solely on the evaluation of the equality comparison of the columns of NOT IN, i.e., A1 = B1, which says - If A.A1 is null, the row is filtered (1.3 and 1.4.2) - If A.A1 = B.B1, the row is filtered (1.1) - If B.B1 is null, any rows of A in the same group (A.A2 = B.B2) is filtered (1.4.1 & 1.4.2) - Otherwise, the row is qualified. Hence, in this group, the result is the row from (1.2). Group 2: The rows in A when there is no match from the correlated predicate (A.A2 = B.B2) In this case, all the rows in A, including the rows where A.A1, are qualified because the subquery returns an empty set and by the semantics of the NOT IN, all rows from the parent side qualifies as the result set, that is, the rows from (2.1, 2.2 and 2.3). In conclusion, the correct result set of the above query is ```` A.A1 A.A2 ----- ----- 2 1 (1.2) 1 null (2.1) null 2 (2.2 & 2.3) ```` ## How was this patch tested? unit tests, regression tests, and new test cases focusing on the problem being fixed. Author: Nattavut Sutyanyong Closes #17294 from nsyca/18966. --- .../sql/catalyst/optimizer/subquery.scala | 13 +++-- .../inputs/subquery/in-subquery/simple-in.sql | 24 +++++++++ .../subquery/in-subquery/simple-in.sql.out | 50 ++++++++++++++++++- 3 files changed, 82 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index ba3fd1d5f802..2a3e07aebe70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -80,14 +80,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) - val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: - // (a1,b1,...) = (a2,b2,...) + // (a1,a2,...) = (b1,b2,...) // to - // (a1=a2 OR isnull(a1=a2)) AND (b1=b2 OR isnull(b1=b2)) AND ... + // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ... val joinConds = splitConjunctivePredicates(joinCond.get) - val pairs = joinConds.map(c => Or(c, IsNull(c))).reduceLeft(And) + // After that, add back the correlated join predicate(s) in the subquery + // Example: + // SELECT ... 
FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) + // will have the final conditions in the LEFT ANTI as + // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) + val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) Join(outerPlan, sub, LeftAnti, Option(pairs)) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql index 20370b045e80..f19567d2fac2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql @@ -109,4 +109,28 @@ FROM t1 WHERE t1a NOT IN (SELECT t2a FROM t2); +-- DDLs +create temporary view a as select * from values + (1, 1), (2, 1), (null, 1), (1, 3), (null, 3), (1, null), (null, 2) + as a(a1, a2); +create temporary view b as select * from values + (1, 1, 2), (null, 3, 2), (1, null, 2), (1, 2, null) + as b(b1, b2, b3); + +-- TC 02.01 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2) +; + +-- TC 02.02 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2 + AND b.b3 > 1) +; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out index 66493d7fcc92..d69b4bcf185c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 14 -- !query 0 @@ -174,3 +174,51 @@ t1a 6 2014-04-04 01:02:00.001 t1d 10 2015-05-04 01:01:00 t1d NULL 2014-06-04 01:01:00 t1d NULL 2014-07-04 01:02:00.001 + + +-- !query 10 +create temporary view a as select * from values + (1, 1), (2, 1), (null, 1), (1, 3), (null, 3), (1, null), (null, 2) + as a(a1, a2) +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +create temporary view b as select * from values + (1, 1, 2), (null, 3, 2), (1, null, 2), (1, 2, null) + as b(b1, b2, b3) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2) +-- !query 12 schema +struct +-- !query 12 output +1 NULL +2 1 + + +-- !query 13 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2 + AND b.b3 > 1) +-- !query 13 schema +struct +-- !query 13 output +1 NULL +2 1 +NULL 2 From 7ded39c223429265b23940ca8244660dbee8320c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 14 Mar 2017 13:57:23 -0700 Subject: [PATCH 032/512] [SPARK-19817][SQL] Make it clear that `timeZone` option is a general option in DataFrameReader/Writer. ## What changes were proposed in this pull request? As timezone setting can also affect partition values, it works for all formats, we should make it clear. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #17281 from ueshin/issues/SPARK-19817. 
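As a usage sketch (the paths and the partition column are placeholders, not part of this patch), the point is that `timeZone` is passed through the generic `option()` call, so it applies to partition values as well as to the JSON/CSV parsers:

```scala
// Reading: timestamps in the JSON input and in partition values are parsed in the
// given zone instead of the session-local default.
val df = spark.read
  .option("timeZone", "America/Los_Angeles")
  .json("/tmp/events_in")

// Writing: the same general option controls how timestamp-typed partition values
// are formatted in the output directory names (column `ts` is hypothetical).
df.write
  .option("timeZone", "GMT")
  .partitionBy("ts")
  .parquet("/tmp/events_out")
```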
--- python/pyspark/sql/readwriter.py | 46 +++++++++++-------- .../sql/catalyst/catalog/interface.scala | 3 +- .../spark/sql/catalyst/json/JSONOptions.scala | 5 +- .../sql/catalyst/util/DateTimeUtils.scala | 2 + .../expressions/JsonExpressionsSuite.scala | 9 ++-- .../apache/spark/sql/DataFrameReader.scala | 22 +++++++-- .../apache/spark/sql/DataFrameWriter.scala | 22 +++++++-- .../execution/OptimizeMetadataOnlyQuery.scala | 2 +- .../datasources/FileFormatWriter.scala | 2 +- .../PartitioningAwareFileIndex.scala | 2 +- .../datasources/csv/CSVOptions.scala | 5 +- .../execution/datasources/csv/CSVSuite.scala | 5 +- .../datasources/json/JsonSuite.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 11 +++-- .../sql/sources/PartitionedWriteSuite.scala | 4 +- .../sql/sources/ResolvedDataSourceSuite.scala | 3 +- .../spark/sql/hive/HiveExternalCatalog.scala | 2 +- 17 files changed, 101 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 4354345ebc55..705803791d89 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -109,6 +109,11 @@ def schema(self, schema): @since(1.5) def option(self, key, value): """Adds an input option for the underlying data source. + + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or parttion values. + If it isn't set, it uses the default value, session local timezone. """ self._jreader = self._jreader.option(key, to_str(value)) return self @@ -116,6 +121,11 @@ def option(self, key, value): @since(1.4) def options(self, **options): """Adds input options for the underlying data source. + + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or parttion values. + If it isn't set, it uses the default value, session local timezone. """ for k in options: self._jreader = self._jreader.option(k, to_str(options[k])) @@ -159,7 +169,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - timeZone=None, wholeFile=None): + wholeFile=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -214,8 +224,6 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. 
@@ -234,7 +242,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) + timestampFormat=timestampFormat, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -307,7 +315,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. @@ -367,8 +375,6 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non uses the default value, ``10``. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ record, and puts the malformed string into a field configured by \ @@ -399,7 +405,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] @@ -521,6 +527,11 @@ def format(self, source): @since(1.5) def option(self, key, value): """Adds an output option for the underlying data source. + + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or parttion values. + If it isn't set, it uses the default value, session local timezone. """ self._jwrite = self._jwrite.option(key, to_str(value)) return self @@ -528,6 +539,11 @@ def option(self, key, value): @since(1.4) def options(self, **options): """Adds output options for the underlying data source. + + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or parttion values. + If it isn't set, it uses the default value, session local timezone. 
""" for k in options: self._jwrite = self._jwrite.option(k, to_str(options[k])) @@ -619,8 +635,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - timeZone=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): """Saves the content of the :class:`DataFrame` in JSON format at the specified path. :param path: the path in any Hadoop supported file system @@ -641,15 +656,12 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to format timestamps. - If None is set, it uses the default value, session local timezone. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - timeZone=timeZone) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) self._jwrite.json(path) @since(1.4) @@ -696,7 +708,7 @@ def text(self, path, compression=None): @since(2.0) def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, - timestampFormat=None, timeZone=None): + timestampFormat=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -736,15 +748,13 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. 
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, - dateFormat=dateFormat, timestampFormat=timestampFormat, timeZone=timeZone) + dateFormat=dateFormat, timestampFormat=timestampFormat) self._jwrite.csv(path) @since(1.5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index e3631b0c0773..b862deaf3636 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -113,7 +113,8 @@ case class CatalogTablePartition( */ def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) - val timeZoneId = caseInsensitiveProperties.getOrElse("timeZone", defaultTimeZondId) + val timeZoneId = caseInsensitiveProperties.getOrElse( + DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId) InternalRow.fromSeq(partitionSchema.map { field => Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval() }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5a91f9c1939a..5f222ec602c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -23,7 +23,7 @@ import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util._ /** * Options for parsing JSON data into Spark SQL rows. @@ -69,7 +69,8 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) - val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId)) + val timeZone: TimeZone = TimeZone.getTimeZone( + parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 9e1de0fd2f3d..9b94c1e2b40b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -60,6 +60,8 @@ object DateTimeUtils { final val TimeZoneGMT = TimeZone.getTimeZone("GMT") final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12) + val TIMEZONE_OPTION = "timeZone" + def defaultTimeZone(): TimeZone = TimeZone.getDefault() // Reuse the Calendar object in each thread as it is expensive to create in each method call. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index e3584909ddc4..19d0c8eb92f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -471,7 +471,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( JsonToStruct( schema, - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> tz.getID), + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), gmtId), InternalRow(c.getTimeInMillis * 1000L) @@ -523,14 +524,16 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( StructToJson( - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> gmtId.get), + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> gmtId.get), struct, gmtId), """{"t":"2016-01-01T00:00:00"}""" ) checkEvaluation( StructToJson( - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> "PST"), + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> "PST"), struct, gmtId), """{"t":"2015-12-31T16:00:00"}""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4f4cc9311749..f1bce1aa4102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -70,6 +70,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Adds an input option for the underlying data source. * + * You can set the following option(s): + *
+   * <ul>
+   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+   * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+   * </ul>
+ * * @since 1.4.0 */ def option(key: String, value: String): DataFrameReader = { @@ -101,6 +107,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * (Scala-specific) Adds input options for the underlying data source. * + * You can set the following option(s): + *
+   * <ul>
+   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+   * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+   * </ul>
+ * * @since 1.4.0 */ def options(options: scala.collection.Map[String, String]): DataFrameReader = { @@ -111,6 +123,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Adds input options for the underlying data source. * + * You can set the following option(s): + *
+   * <ul>
+   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+   * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+   * </ul>
+ * * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameReader = { @@ -305,8 +323,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
   * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
   * indicates a timestamp format. Custom date formats follow the formats at
   * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
-   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
-   * to be used to parse timestamps.</li>
   * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines,
   * per file</li>
  • * @@ -478,8 +494,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
   * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
   * indicates a timestamp format. Custom date formats follow the formats at
   * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
-   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
-   * to be used to parse timestamps.</li>
   * <li>`maxColumns` (default `20480`): defines a hard limit of how many columns
   * a record can have.</li>
   *
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 49e85dc7b13f..608160a214fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -90,6 +90,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Adds an output option for the underlying data source. * + * You can set the following option(s): + *
+   * <ul>
+   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+   * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+   * </ul>
    + * * @since 1.4.0 */ def option(key: String, value: String): DataFrameWriter[T] = { @@ -121,6 +127,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * (Scala-specific) Adds output options for the underlying data source. * + * You can set the following option(s): + *
+   * <ul>
+   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+   * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+   * </ul>
    + * * @since 1.4.0 */ def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = { @@ -131,6 +143,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Adds output options for the underlying data source. * + * You can set the following option(s): + *
+   * <ul>
+   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+   * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+   * </ul>
    + * * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameWriter[T] = { @@ -457,8 +475,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
 * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
 * indicates a timestamp format. Custom date formats follow the formats at
 * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
- * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to format timestamps.</li>
 * </ul>
 *
 * @since 1.4.0
@@ -565,8 +581,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
 * indicates a timestamp format. Custom date formats follow the formats at
 * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
- * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to format timestamps.</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index aa578f4d2313..769deb1890b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -105,7 +105,7 @@ case class OptimizeMetadataOnlyQuery( val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) val caseInsensitiveProperties = CaseInsensitiveMap(relation.tableMeta.storage.properties) - val timeZoneId = caseInsensitiveProperties.get("timeZone") + val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(conf.sessionLocalTimeZone) val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => InternalRow.fromSeq(partAttrs.map { attr => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 30a09a9ad337..ce33298aeb1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -141,7 +141,7 @@ object FileFormatWriter extends Logging { customPartitionLocations = outputSpec.customPartitionLocations, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), - timeZoneId = caseInsensitiveOptions.get("timeZone") + timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index c8097a7fabc2..a5fa8b3f9385 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -127,7 +127,7 @@ abstract class PartitioningAwareFileIndex( }.keys.toSeq val caseInsensitiveOptions = CaseInsensitiveMap(parameters) - val timeZoneId = caseInsensitiveOptions.get("timeZone") + val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) userPartitionSchema match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 0b1e5dac2da6..2632e87971d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -24,7 +24,7 @@ import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, Unescape import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util._ class CSVOptions( @transient private val parameters: CaseInsensitiveMap[String], @@ -120,7 +120,8 @@ 
class CSVOptions( name.map(CompressionCodecs.getCodecClassName) } - val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId)) + val timeZone: TimeZone = TimeZone.getTimeZone( + parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4435e4df38ef..95dfdf5b298e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -912,7 +913,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .format("csv") .option("header", "true") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .save(timestampsWithFormatPath) // This will load back the timestamps as string. @@ -934,7 +935,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .option("inferSchema", "true") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .load(timestampsWithFormatPath) checkAnswer(readBack, timestampsWithFormat) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0aaf148dac25..9b0efcbdaf5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1767,7 +1767,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { timestampsWithFormat.write .format("json") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .save(timestampsWithFormatPath) // This will load back the timestamps as string. 
@@ -1785,7 +1785,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val readBack = spark.read .schema(customSchema) .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .json(timestampsWithFormatPath) checkAnswer(readBack, timestampsWithFormat) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 88cb8a0bad21..2b20b9716bf8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} import org.apache.spark.sql.functions._ @@ -708,10 +709,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } withTempPath { dir => - df.write.option("timeZone", "GMT") + df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name).cast(f.dataType)) - checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + checkAnswer(spark.read.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .load(dir.toString).select(fields: _*), row) } } @@ -749,10 +751,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } withTempPath { dir => - df.write.option("timeZone", "GMT") + df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name)) - checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + checkAnswer(spark.read.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .load(dir.toString).select(fields: _*), row) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index f251290583c5..a2f3afe3ce23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -142,7 +143,8 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => - df.write.option("timeZone", "GMT").partitionBy("ts").parquet(f.getAbsolutePath) + 
df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .partitionBy("ts").parquet(f.getAbsolutePath) val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 9b5e364e512a..0f97fd78d2ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -27,7 +27,8 @@ class ResolvedDataSourceSuite extends SparkFunSuite { DataSource( sparkSession = null, className = name, - options = Map("timeZone" -> DateTimeUtils.defaultTimeZone().getID)).providingClass + options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID) + ).providingClass test("jdbc") { assert( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 33802ae62333..8860b7dc079c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.hive.client.HiveClient From dacc382f0c918f1ca808228484305ce0e21c705e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Mar 2017 08:24:41 +0800 Subject: [PATCH 033/512] [SPARK-19887][SQL] dynamic partition keys can be null or empty string ## What changes were proposed in this pull request? When dynamic partition value is null or empty string, we should write the data to a directory like `a=__HIVE_DEFAULT_PARTITION__`, when we read the data back, we should respect this special directory name and treat it as null. This is the same behavior of impala, see https://issues.apache.org/jira/browse/IMPALA-252 ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17277 from cloud-fan/partition. 
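As a quick illustration of the behavior described above, here is a minimal sketch, not taken from the patch itself; the output path and column names are made up for the example:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical session and output path, for illustration only.
val spark = SparkSession.builder().appName("null-partition-sketch").getOrCreate()
import spark.implicits._

// "b" is a dynamic partition column; one row has a null value and one an empty string.
Seq((1, "p"), (2, null: String), (3, "")).toDF("a", "b")
  .write.partitionBy("b").mode("overwrite").parquet("/tmp/null_partition_sketch")

// After this change, both the null and the empty string are expected to land in a
// b=__HIVE_DEFAULT_PARTITION__ directory and to be read back as null.
val readBack = spark.read.parquet("/tmp/null_partition_sketch")
readBack.filter($"b".isNull).show()
```

This mirrors the new regression test added below, which also checks that partition pruning still works against the rewritten partition values.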
--- .../catalog/ExternalCatalogUtils.scala | 2 +- .../sql/catalyst/catalog/interface.scala | 9 +++++-- .../sql/execution/DataSourceScanExec.scala | 2 +- .../datasources/FileFormatWriter.scala | 11 ++++----- .../datasources/PartitioningUtils.scala | 3 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 4 ++-- .../PartitionProviderCompatibilitySuite.scala | 24 ++++++++++++++++++- 7 files changed, 39 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index a418edc302d9..a8693dcca539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -118,7 +118,7 @@ object ExternalCatalogUtils { } def getPartitionPathString(col: String, value: String): String = { - val partitionString = if (value == null) { + val partitionString = if (value == null || value.isEmpty) { DEFAULT_PARTITION_NAME } else { escapePathName(value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b862deaf3636..70ed44e025f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -116,7 +116,12 @@ case class CatalogTablePartition( val timeZoneId = caseInsensitiveProperties.getOrElse( DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId) InternalRow.fromSeq(partitionSchema.map { field => - Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval() + val partValue = if (spec(field.name) == ExternalCatalogUtils.DEFAULT_PARTITION_NAME) { + null + } else { + spec(field.name) + } + Cast(Literal(partValue), field.dataType, Option(timeZoneId)).eval() }) } } @@ -164,7 +169,7 @@ case class BucketSpec( * @param tracksPartitionsInCatalog whether this table's partition metadata is stored in the * catalog. If false, it is inferred automatically based on file * structure. - * @param schemaPresevesCase Whether or not the schema resolved for this table is case-sensitive. + * @param schemaPreservesCase Whether or not the schema resolved for this table is case-sensitive. * When using a Hive Metastore, this flag is set to false if a case- * sensitive schema was unable to be read from the table properties. 
* Used to trigger case-sensitive schema inference at query time, when diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 39b010efec7b..8ebad676ca31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -319,7 +319,7 @@ case class FileSourceScanExec( val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => - new BoundReference(i, a.dataType, a.nullable) + BoundReference(i, a.dataType, a.nullable) } val row = ctx.freshName("row") ctx.INPUT_ROW = row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index ce33298aeb1d..7957224ce48b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -335,14 +335,11 @@ object FileFormatWriter extends Logging { /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ private def partitionPathExpression: Seq[Expression] = { desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val escaped = ScalaUDF( - ExternalCatalogUtils.escapePathName _, + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, StringType, - Seq(Cast(c, StringType, Option(desc.timeZoneId))), - Seq(StringType)) - val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) - val partitionName = Literal(ExternalCatalogUtils.escapePathName(c.name) + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName + Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 09876bbc2f85..03980922ab38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. @@ -129,7 +128,7 @@ object PartitioningUtils { // "hdfs://host:9000/invalidPath" // "hdfs://host:9000/path" // TODO: Selective case sensitivity. - val discoveredBasePaths = optDiscoveredBasePaths.flatMap(x => x).map(_.toString.toLowerCase()) + val discoveredBasePaths = optDiscoveredBasePaths.flatten.map(_.toString.toLowerCase()) assert( discoveredBasePaths.distinct.size == 1, "Conflicting directory structures detected. 
Suspicious paths:\b" + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 8860b7dc079c..8a3c81ac8b0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1012,8 +1012,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val partColNameMap = buildLowerCasePartColNameMap(catalogTable).mapValues(escapePathName) val clientPartitionNames = client.getPartitionNames(catalogTable, partialSpec.map(lowerCasePartitionSpec)) - clientPartitionNames.map { partName => - val partSpec = PartitioningUtils.parsePathFragmentAsSeq(partName) + clientPartitionNames.map { partitionPath => + val partSpec = PartitioningUtils.parsePathFragmentAsSeq(partitionPath) partSpec.map { case (partName, partValue) => partColNameMap(partName.toLowerCase) + "=" + escapePathName(partValue) }.mkString("/") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 96385961c9a5..9440a17677eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -22,7 +22,7 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.metrics.source.HiveCatalogMetrics -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -316,6 +316,28 @@ class PartitionProviderCompatibilitySuite } } } + + test(s"SPARK-19887 partition value is null - partition management $enabled") { + withTable("test") { + Seq((1, "p", 1), (2, null, 2)).toDF("a", "b", "c") + .write.partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Nil) + + Seq((3, null: String, 3)).toDF("a", "b", "c") + .write.mode("append").partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Row(3, null, 3) :: Nil) + // make sure partition pruning also works. + checkAnswer(spark.table("test").filter($"b".isNotNull), Row(1, "p", 1)) + + // empty string is an invalid partition value and we treat it as null when read back. + Seq((4, "", 4)).toDF("a", "b", "c") + .write.mode("append").partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Row(3, null, 3) :: Row(4, null, 4) :: Nil) + } + } } /** From 8fb2a02e2ce6832e3d9338a7d0148dfac9fa24c2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 15 Mar 2017 10:19:19 +0800 Subject: [PATCH 034/512] [SPARK-19918][SQL] Use TextFileFormat in implementation of TextInputJsonDataSource ## What changes were proposed in this pull request? This PR proposes to use text datasource when Json schema inference. This basically proposes the similar approach in https://github.com/apache/spark/pull/15813 If we use Dataset for initial loading when inferring the schema, there are advantages. 
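For context, a minimal sketch of the user-facing path this touches (illustrative only; the input path is made up): the raw lines are loaded through the text source as a `Dataset[String]`, and the JSON reader infers the schema from that Dataset.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical session and input path, for illustration only.
val spark = SparkSession.builder().appName("json-infer-sketch").getOrCreate()

// The text datasource yields a Dataset[String] of raw JSON lines ...
val lines = spark.read.textFile("/tmp/people.json")

// ... and the JSON reader can infer the schema directly from that Dataset,
// which is the code path TextInputJsonDataSource.inferFromDataset backs in this patch.
val df = spark.read.json(lines)
df.printSchema()
```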
Please refer SPARK-18362 It seems JSON one was supposed to be fixed together but taken out according to https://github.com/apache/spark/pull/15813 > A similar problem also affects the JSON file format and this patch originally fixed that as well, but I've decided to split that change into a separate patch so as not to conflict with changes in another JSON PR. Also, this seems affecting some functionalities because it does not use `FileScanRDD`. This problem is described in SPARK-19885 (but it was CSV's case). ## How was this patch tested? Existing tests should cover this and manual test by `spark.read.json(path)` and check the UI. Author: hyukjinkwon Closes #17255 from HyukjinKwon/json-filescanrdd. --- .../apache/spark/sql/DataFrameReader.scala | 9 +- .../datasources/json/JsonDataSource.scala | 145 ++++++++---------- .../datasources/json/JsonFileFormat.scala | 2 +- .../datasources/json/JsonInferSchema.scala | 9 +- .../datasources/json/JsonUtils.scala | 51 ++++++ 5 files changed, 122 insertions(+), 94 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index f1bce1aa4102..309654c80414 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema +import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -376,17 +376,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val createParser = CreateJacksonParser.string _ val schema = userSpecifiedSchema.getOrElse { - JsonInferSchema.infer( - jsonDataset.rdd, - parsedOptions, - createParser) + TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions) } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val createParser = CreateJacksonParser.string _ val parsed = jsonDataset.rdd.mapPartitions { iter => val parser = new JacksonParser(schema, parsedOptions) iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 18843bfc307b..84f026620d90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,32 +17,30 @@ package org.apache.spark.sql.execution.datasources.json -import scala.reflect.ClassTag - import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.io.{LongWritable, Text} +import 
org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.Job -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** * Common functions for parsing JSON files - * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]] */ -abstract class JsonDataSource[T] extends Serializable { +abstract class JsonDataSource extends Serializable { def isSplitable: Boolean /** @@ -53,28 +51,12 @@ abstract class JsonDataSource[T] extends Serializable { file: PartitionedFile, parser: JacksonParser): Iterator[InternalRow] - /** - * Create an [[RDD]] that handles the preliminary parsing of [[T]] records - */ - protected def createBaseRdd( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[T] - - /** - * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]] - * for an instance of [[T]] - */ - def createParser(jsonFactory: JsonFactory, value: T): JsonParser - - final def infer( + final def inferSchema( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Option[StructType] = { if (inputPaths.nonEmpty) { - val jsonSchema = JsonInferSchema.infer( - createBaseRdd(sparkSession, inputPaths), - parsedOptions, - createParser) + val jsonSchema = infer(sparkSession, inputPaths, parsedOptions) checkConstraints(jsonSchema) Some(jsonSchema) } else { @@ -82,6 +64,11 @@ abstract class JsonDataSource[T] extends Serializable { } } + protected def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType + /** Constraints to be imposed on schema to be stored. */ private def checkConstraints(schema: StructType): Unit = { if (schema.fieldNames.length != schema.fieldNames.distinct.length) { @@ -95,53 +82,46 @@ abstract class JsonDataSource[T] extends Serializable { } object JsonDataSource { - def apply(options: JSONOptions): JsonDataSource[_] = { + def apply(options: JSONOptions): JsonDataSource = { if (options.wholeFile) { WholeFileJsonDataSource } else { TextInputJsonDataSource } } - - /** - * Create a new [[RDD]] via the supplied callback if there is at least one file to process, - * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned. 
- */ - def createBaseRdd[T : ClassTag]( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus])( - fn: (Configuration, String) => RDD[T]): RDD[T] = { - val paths = inputPaths.map(_.getPath) - - if (paths.nonEmpty) { - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) - FileInputFormat.setInputPaths(job, paths: _*) - fn(job.getConfiguration, paths.mkString(",")) - } else { - sparkSession.sparkContext.emptyRDD[T] - } - } } -object TextInputJsonDataSource extends JsonDataSource[Text] { +object TextInputJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { // splittable if the underlying source is true } - override protected def createBaseRdd( + override def infer( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[Text] = { - JsonDataSource.createBaseRdd(sparkSession, inputPaths) { - case (conf, name) => - sparkSession.sparkContext.newAPIHadoopRDD( - conf, - classOf[TextInputFormat], - classOf[LongWritable], - classOf[Text]) - .setName(s"JsonLines: $name") - .values // get the text column - } + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType = { + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + inferFromDataset(json, parsedOptions) + } + + def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { + val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) + val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) + JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + } + + private def createBaseDataset( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value").as(Encoders.STRING) } override def readFile( @@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] { parser: JacksonParser): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - linesReader.flatMap(parser.parse(_, createParser, textToUTF8String)) + linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String)) } private def textToUTF8String(value: Text): UTF8String = { UTF8String.fromBytes(value.getBytes, 0, value.getLength) } - - override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = { - CreateJacksonParser.text(jsonFactory, value) - } } -object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { +object WholeFileJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { false } - override protected def createBaseRdd( + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType = { + val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) + val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) + JsonInferSchema.infer(sampled, parsedOptions, createParser) + } + + private def createBaseRdd( sparkSession: SparkSession, inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { - JsonDataSource.createBaseRdd(sparkSession, inputPaths) { - case (conf, name) => - new BinaryFileRDD( - sparkSession.sparkContext, - classOf[StreamInputFormat], - 
classOf[String], - classOf[PortableDataStream], - conf, - sparkSession.sparkContext.defaultMinPartitions) - .setName(s"JsonFile: $name") - .values - } + val paths = inputPaths.map(_.getPath) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val conf = job.getConfiguration + val name = paths.mkString(",") + FileInputFormat.setInputPaths(job, paths: _*) + new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + .setName(s"JsonFile: $name") + .values } - override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { CreateJacksonParser.inputStream( jsonFactory, CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 902fee5a7e3f..a9dd91eba6f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) - JsonDataSource(parsedOptions).infer( + JsonDataSource(parsedOptions).inferSchema( sparkSession, files, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index ab09358115c0..7475f8ec7933 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -40,18 +40,11 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - require(configOptions.samplingRatio > 0, - s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") val shouldHandleCorruptRecord = configOptions.permissive val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - val schemaData = if (configOptions.samplingRatio > 0.99) { - json - } else { - json.sample(withReplacement = false, configOptions.samplingRatio, 1) - } // perform schema inference on each row and merge afterwards - val rootType = schemaData.mapPartitions { iter => + val rootType = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala new file mode 100644 index 000000000000..d511594c5de1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.json.JSONOptions + +object JsonUtils { + /** + * Sample JSON dataset as configured by `samplingRatio`. + */ + def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + json + } else { + json.sample(withReplacement = false, options.samplingRatio, 1) + } + } + + /** + * Sample JSON RDD as configured by `samplingRatio`. + */ + def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + json + } else { + json.sample(withReplacement = false, options.samplingRatio, 1) + } + } +} From d1f6c64c4b763c05d6d79ae5497f298dc3835f3e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 14 Mar 2017 19:51:25 -0700 Subject: [PATCH 035/512] [SPARK-19828][R] Support array type in from_json in R ## What changes were proposed in this pull request? Since we could not directly define the array type in R, this PR proposes to support array types in R as string types that are used in `structField` as below: ```R jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) collect(select(df, alias(from_json(df$people, "array>"), "arrcol"))) ``` prints ```R arrcol 1 Bob, Alice ``` ## How was this patch tested? Unit tests in `test_sparkSQL.R`. Author: hyukjinkwon Closes #17178 from HyukjinKwon/SPARK-19828. --- R/pkg/R/functions.R | 12 ++++++++++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 ++++++++++++ .../scala/org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index edf2bcf8fdb3..9867f2d5b7c5 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2437,6 +2437,7 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' @param x Column containing the JSON string. #' @param schema a structType object to use as the schema to use when parsing the JSON string. +#' @param asJsonArray indicating if input string is JSON array of objects or a single object. #' @param ... additional named properties to control how the json is parsed, accepts the same #' options as the JSON data source. 
#' @@ -2452,11 +2453,18 @@ setMethod("date_format", signature(y = "Column", x = "character"), #'} #' @note from_json since 2.2.0 setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, ...) { + function(x, schema, asJsonArray = FALSE, ...) { + if (asJsonArray) { + jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", + "createArrayType", + schema$jobj) + } else { + jschema <- schema$jobj + } options <- varargsToStrEnv(...) jc <- callJStatic("org.apache.spark.sql.functions", "from_json", - x@jc, schema$jobj, options) + x@jc, jschema, options) column(jc) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9735fe320155..f7081cb1d4e5 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1364,6 +1364,18 @@ test_that("column functions", { # check for unparseable df <- as.DataFrame(list(list("a" = ""))) expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) + + # check if array type in string is correctly supported. + jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" + df <- as.DataFrame(list(list("people" = jsonArr))) + schema <- structType(structField("name", "string")) + arr <- collect(select(df, alias(from_json(df$people, schema, asJsonArray = TRUE), "arrcol"))) + expect_equal(ncol(arr), 1) + expect_equal(nrow(arr), 1) + expect_is(arr[[1]][[1]], "list") + expect_equal(length(arr$arrcol[[1]]), 2) + expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") + expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") }) test_that("column binary mathfunctions", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index a4c5bf756cd5..c77328690dae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -81,7 +81,7 @@ private[sql] object SQLUtils extends Logging { new JavaSparkContext(spark.sparkContext) } - def createStructType(fields : Seq[StructField]): StructType = { + def createStructType(fields: Seq[StructField]): StructType = { StructType(fields) } From f9a93b1b4a20e7c72d900362b269edab66e73dd8 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 15 Mar 2017 10:53:58 +0800 Subject: [PATCH 036/512] [SPARK-18112][SQL] Support reading data from Hive 2.1 metastore ### What changes were proposed in this pull request? This PR is to support reading data from Hive 2.1 metastore. Need to update shim class because of the Hive API changes caused by the following three Hive JIRAs: - [HIVE-12730 MetadataUpdater: provide a mechanism to edit the basic statistics of a table (or a partition)](https://issues.apache.org/jira/browse/HIVE-12730) - [Hive-13341 Stats state is not captured correctly: differentiate load table and create table](https://issues.apache.org/jira/browse/HIVE-13341) - [HIVE-13622 WriteSet tracking optimizations](https://issues.apache.org/jira/browse/HIVE-13622) There are three new fields added to Hive APIs. - `boolean hasFollowingStatsTask`. We always set it to `false`. This is to keep the existing behavior unchanged (starting from 0.13), no matter which Hive metastore client version users choose. If we set it to `true`, the basic table statistics is not collected by Hive. 
For example, ```SQL CREATE TABLE tbl AS SELECT 1 AS a ``` When setting `hasFollowingStatsTask ` to `false`, the table properties is like ``` Properties: [numFiles=1, transient_lastDdlTime=1489513927, totalSize=2] ``` When setting `hasFollowingStatsTask ` to `true`, the table properties is like ``` Properties: [transient_lastDdlTime=1489513563] ``` - `AcidUtils.Operation operation`. Obviously, we do not support ACID. Thus, we set it to `AcidUtils.Operation.NOT_ACID`. - `EnvironmentContext environmentContext`. So far, this is always set to `null`. This was introduced for supporting DDL `alter table s update statistics set ('numRows'='NaN')`. Using this DDL, users can specify the statistics. So far, our Spark SQL does not need it, because we use different table properties to store our generated statistics values. However, when Spark SQL issues ALTER TABLE DDL statements, Hive metastore always automatically invalidate the Hive-generated statistics. In the follow-up PR, we can fix it by explicitly adding a property to `environmentContext`. ```JAVA putToProperties(StatsSetupConst.STATS_GENERATED, StatsSetupConst.USER) ``` Another alternative is to set `DO_NOT_UPDATE_STATS`to `TRUE`. See the Hive JIRA: https://issues.apache.org/jira/browse/HIVE-15653. We will not address it in this PR. ### How was this patch tested? Added test cases to VersionsSuite.scala Author: Xiao Li Closes #17232 from gatorsmile/Hive21. --- .../spark/sql/hive/HiveExternalCatalog.scala | 1 - .../sql/hive/client/HiveClientImpl.scala | 5 +- .../spark/sql/hive/client/HiveShim.scala | 181 ++++++++++++++++-- .../hive/client/IsolatedClientLoader.scala | 1 + .../spark/sql/hive/client/package.scala | 6 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 19 +- 7 files changed, 190 insertions(+), 25 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 8a3c81ac8b0f..33b21be37203 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive import java.io.IOException import java.lang.reflect.InvocationTargetException -import java.net.URI import java.util import scala.collection.mutable diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 6e1f429286cf..989fdc5564d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -97,6 +97,7 @@ private[hive] class HiveClientImpl( case hive.v1_1 => new Shim_v1_1() case hive.v1_2 => new Shim_v1_2() case hive.v2_0 => new Shim_v2_0() + case hive.v2_1 => new Shim_v2_1() } // Create an internal session state for this HiveClientImpl. 
@@ -455,7 +456,7 @@ private[hive] class HiveClientImpl( val hiveTable = toHiveTable(table, Some(conf)) // Do not use `table.qualifiedName` here because this may be a rename val qualifiedTableName = s"${table.database}.$tableName" - client.alterTable(qualifiedTableName, hiveTable) + shim.alterTable(client, qualifiedTableName, hiveTable) } override def createPartitions( @@ -535,7 +536,7 @@ private[hive] class HiveClientImpl( table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { val hiveTable = toHiveTable(getTable(db, table), Some(conf)) - client.alterPartitions(table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 153f1673c96f..76568f599078 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -28,8 +28,10 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.{EnvironmentContext, Function => HiveFunction, FunctionType} +import org.apache.hadoop.hive.metastore.api.{MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.io.AcidUtils import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException, Partition, Table} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} @@ -82,6 +84,10 @@ private[client] sealed abstract class Shim { def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + def alterTable(hive: Hive, tableName: String, table: Table): Unit + + def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit + def createPartitions( hive: Hive, db: String, @@ -158,6 +164,10 @@ private[client] sealed abstract class Shim { } private[client] class Shim_v0_12 extends Shim with Logging { + // See HIVE-12224, HOLD_DDLTIME was broken as soon as it landed + protected lazy val holdDDLTime = JBoolean.FALSE + // deletes the underlying data along with metadata + protected lazy val deleteDataInDropIndex = JBoolean.TRUE private lazy val startMethod = findStaticMethod( @@ -240,6 +250,18 @@ private[client] class Shim_v0_12 extends Shim with Logging { classOf[String], classOf[String], JBoolean.TYPE) + private lazy val alterTableMethod = + findMethod( + classOf[Hive], + "alterTable", + classOf[String], + classOf[Table]) + private lazy val alterPartitionsMethod = + findMethod( + classOf[Hive], + "alterPartitions", + classOf[String], + classOf[JList[Partition]]) override def setCurrentSessionState(state: SessionState): Unit = { // Starting from Hive 0.13, setCurrentSessionState will internally override @@ -341,7 +363,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { tableName: String, replace: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE) + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime) } override def loadDynamicPartitions( @@ 
-353,11 +375,11 @@ private[client] class Shim_v0_12 extends Shim with Logging { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean) } override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { - dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + dropIndexMethod.invoke(hive, dbName, tableName, indexName, deleteDataInDropIndex) } override def dropTable( @@ -373,6 +395,14 @@ private[client] class Shim_v0_12 extends Shim with Logging { hive.dropTable(dbName, tableName, deleteData, ignoreIfNotExists) } + override def alterTable(hive: Hive, tableName: String, table: Table): Unit = { + alterTableMethod.invoke(hive, tableName, table) + } + + override def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit = { + alterPartitionsMethod.invoke(hive, tableName, newParts) + } + override def dropPartition( hive: Hive, dbName: String, @@ -520,7 +550,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } FunctionResource(FunctionResourceType.fromString(resourceType), uri.getUri()) } - new CatalogFunction(name, hf.getClassName, resources) + CatalogFunction(name, hf.getClassName, resources) } override def getFunctionOption(hive: Hive, db: String, name: String): Option[CatalogFunction] = { @@ -638,6 +668,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { private[client] class Shim_v0_14 extends Shim_v0_13 { + // true if this is an ACID operation + protected lazy val isAcid = JBoolean.FALSE + // true if list bucketing enabled + protected lazy val isSkewedStoreAsSubdir = JBoolean.FALSE + private lazy val loadPartitionMethod = findMethod( classOf[Hive], @@ -700,8 +735,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - JBoolean.FALSE, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - isSrcLocal: JBoolean, JBoolean.FALSE) + holdDDLTime, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, isAcid) } override def loadTable( @@ -710,8 +745,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { tableName: String, replace: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE, - isSrcLocal: JBoolean, JBoolean.FALSE, JBoolean.FALSE) + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime, + isSrcLocal: JBoolean, isSkewedStoreAsSubdir, isAcid) } override def loadDynamicPartitions( @@ -723,7 +758,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean, isAcid) } override def dropTable( @@ -752,6 +787,9 @@ private[client] class Shim_v1_0 extends Shim_v0_14 { private[client] class Shim_v1_1 extends Shim_v1_0 { + // throws an exception if the index does not exist + protected lazy val throwExceptionInDropIndex = JBoolean.TRUE + private lazy val dropIndexMethod = findMethod( classOf[Hive], @@ -763,13 +801,17 @@ 
private[client] class Shim_v1_1 extends Shim_v1_0 { JBoolean.TYPE) override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { - dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + dropIndexMethod.invoke(hive, dbName, tableName, indexName, throwExceptionInDropIndex, + deleteDataInDropIndex) } } private[client] class Shim_v1_2 extends Shim_v1_1 { + // txnId can be 0 unless isAcid == true + protected lazy val txnIdInLoadDynamicPartitions: JLong = 0L + private lazy val loadDynamicPartitionsMethod = findMethod( classOf[Hive], @@ -806,8 +848,8 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE, - 0L: JLong) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean, isAcid, + txnIdInLoadDynamicPartitions) } override def dropPartition( @@ -872,7 +914,106 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - isSrcLocal: JBoolean, JBoolean.FALSE) + isSrcLocal: JBoolean, isAcid) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + isSrcLocal: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, + isSkewedStoreAsSubdir, isAcid) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, listBucketingEnabled: JBoolean, isAcid, txnIdInLoadDynamicPartitions) + } + +} + +private[client] class Shim_v2_1 extends Shim_v2_0 { + + // true if there is any following stats task + protected lazy val hasFollowingStatsTask = JBoolean.FALSE + // TODO: Now, always set environmentContext to null. In the future, we should avoid setting + // hive-generated stats to -1 when altering tables by using environmentContext. 
See Hive-12730 + protected lazy val environmentContextInAlterTable = null + + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE, + JBoolean.TYPE, + classOf[AcidUtils.Operation]) + private lazy val alterTableMethod = + findMethod( + classOf[Hive], + "alterTable", + classOf[String], + classOf[Table], + classOf[EnvironmentContext]) + private lazy val alterPartitionsMethod = + findMethod( + classOf[Hive], + "alterPartitions", + classOf[String], + classOf[JList[Partition]], + classOf[EnvironmentContext]) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean, + isSrcLocal: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, isAcid, hasFollowingStatsTask) } override def loadTable( @@ -882,7 +1023,7 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { replace: Boolean, isSrcLocal: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, - JBoolean.FALSE, JBoolean.FALSE) + isSkewedStoreAsSubdir, isAcid, hasFollowingStatsTask) } override def loadDynamicPartitions( @@ -894,7 +1035,15 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, listBucketingEnabled: JBoolean, JBoolean.FALSE, 0L: JLong) + numDP: JInteger, listBucketingEnabled: JBoolean, isAcid, txnIdInLoadDynamicPartitions, + hasFollowingStatsTask, AcidUtils.Operation.NOT_ACID) } + override def alterTable(hive: Hive, tableName: String, table: Table): Unit = { + alterTableMethod.invoke(hive, tableName, table, environmentContextInAlterTable) + } + + override def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit = { + alterPartitionsMethod.invoke(hive, tableName, newParts, environmentContextInAlterTable) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 6f69a4adf29d..e95f9ea48043 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -95,6 +95,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.1" | "1.1.0" => hive.v1_1 case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 + case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 790ad74e6639..f9635e36549e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -67,7 +67,11 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0) + case object v2_1 extends HiveVersion("2.1.1", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index b8536d0c1bd5..3682dc850790 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -149,7 +149,7 @@ case class InsertIntoHiveTable( // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) // Ensure all the supported versions are considered here. assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index cb1386111035..7aff49c0fc3b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayOutputStream, File, PrintStream} import java.net.URI import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -108,7 +109,7 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0") + private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") private var client: HiveClient = null @@ -120,10 +121,12 @@ class VersionsSuite extends SparkFunSuite with Logging { System.gc() // Hack to avoid SEGV on some JVM versions. 
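As a side note for users of the new `v2_1` client: the sketch below is a hedged example (not part of this patch) of pointing a SparkSession at a Hive 2.1 metastore; the application name and the `maven` jar-resolution mode are illustrative assumptions.

```scala
// Illustrative only: selects the Hive 2.1 client added above via IsolatedClientLoader.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("hive-2-1-metastore-example")                 // hypothetical app name
  .config("spark.sql.hive.metastore.version", "2.1.1")   // matched by the "2.1" | "2.1.0" | "2.1.1" case
  .config("spark.sql.hive.metastore.jars", "maven")      // download matching client jars
  .enableHiveSupport()
  .getOrCreate()
```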
val hadoopConf = new Configuration() hadoopConf.set("test", "success") - // Hive changed the default of datanucleus.schema.autoCreateAll from true to false since 2.0 - // For details, see the JIRA HIVE-6113 - if (version == "2.0") { + // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and + // hive.metastore.schema.verification from false to true since 2.0 + // For details, see the JIRA HIVE-6113 and HIVE-12463 + if (version == "2.0" || version == "2.1") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + hadoopConf.set("hive.metastore.schema.verification", "false") } client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) if (versionSpark != null) versionSpark.reset() @@ -572,6 +575,14 @@ class VersionsSuite extends SparkFunSuite with Logging { withTable("tbl") { versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) + val tableMeta = versionSpark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")) + val totalSize = tableMeta.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty) + } else { + assert(totalSize.nonEmpty && totalSize.get > 0) + } } } From e1ac553402ab82bbc72fd64e5943b71c16b4b37d Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Tue, 14 Mar 2017 22:30:16 -0700 Subject: [PATCH 037/512] [SPARK-19817][SS] Make it clear that `timeZone` is a general option in DataStreamReader/Writer ## What changes were proposed in this pull request? As timezone setting can also affect partition values, it works for all formats, we should make it clear. ## How was this patch tested? N/A Author: Liwei Lin Closes #17299 from lw-lin/timezone. --- python/pyspark/sql/readwriter.py | 8 ++--- python/pyspark/sql/streaming.py | 32 ++++++++++++++----- .../apache/spark/sql/DataFrameReader.scala | 6 ++-- .../apache/spark/sql/DataFrameWriter.scala | 6 ++-- .../sql/streaming/DataStreamReader.scala | 22 ++++++++++--- .../sql/streaming/DataStreamWriter.scala | 18 +++++++++++ 6 files changed, 70 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 705803791d89..122e17f2020f 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -112,7 +112,7 @@ def option(self, key, value): You can set the following option(s) for reading files: * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps - in the JSON/CSV datasources or parttion values. + in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ self._jreader = self._jreader.option(key, to_str(value)) @@ -124,7 +124,7 @@ def options(self, **options): You can set the following option(s) for reading files: * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps - in the JSON/CSV datasources or parttion values. + in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ for k in options: @@ -530,7 +530,7 @@ def option(self, key, value): You can set the following option(s) for writing files: * ``timeZone``: sets the string that indicates a timezone to be used to format - timestamps in the JSON/CSV datasources or parttion values. + timestamps in the JSON/CSV datasources or partition values. 
If it isn't set, it uses the default value, session local timezone. """ self._jwrite = self._jwrite.option(key, to_str(value)) @@ -542,7 +542,7 @@ def options(self, **options): You can set the following option(s) for writing files: * ``timeZone``: sets the string that indicates a timezone to be used to format - timestamps in the JSON/CSV datasources or parttion values. + timestamps in the JSON/CSV datasources or partition values. If it isn't set, it uses the default value, session local timezone. """ for k in options: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 625fb9ba385a..288cc1e4f64d 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -373,6 +373,11 @@ def schema(self, schema): def option(self, key, value): """Adds an input option for the underlying data source. + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. >>> s = spark.readStream.option("x", 1) @@ -384,6 +389,11 @@ def option(self, key, value): def options(self, **options): """Adds input options for the underlying data source. + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. >>> s = spark.readStream.options(x="1", y=2) @@ -429,7 +439,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - timeZone=None, wholeFile=None): + wholeFile=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -486,8 +496,6 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. 
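For orientation, here is a hedged Scala sketch of the behavior these docstrings describe: `timeZone` is passed like any other read/write option and affects JSON/CSV timestamps as well as partition values. The paths, timezone IDs, and column name below are placeholders.

```scala
// Reading: timestamps in the CSV are parsed in the given zone (placeholder path and zone).
val df = spark.read
  .option("timeZone", "America/Los_Angeles")
  .option("header", "true")
  .csv("/tmp/events.csv")

// Writing: timestamps and partition values are formatted in UTC (hypothetical column).
df.write
  .option("timeZone", "UTC")
  .partitionBy("event_date")
  .json("/tmp/events_json")
```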
@@ -503,7 +511,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) + timestampFormat=timestampFormat, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -561,7 +569,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. @@ -619,8 +627,6 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``-1`` meaning unlimited length. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ record, and puts the malformed string into a field configured by \ @@ -653,7 +659,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) @@ -721,6 +727,11 @@ def format(self, source): def option(self, key, value): """Adds an output option for the underlying data source. + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. """ self._jwrite = self._jwrite.option(key, to_str(value)) @@ -730,6 +741,11 @@ def option(self, key, value): def options(self, **options): """Adds output options for the underlying data source. + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. 
""" for k in options: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 309654c80414..88fbfb4c92a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -73,7 +73,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * You can set the following option(s): *
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to parse timestamps in the JSON/CSV datasources or parttion values.</li>
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
* </ul>
    * * @since 1.4.0 @@ -110,7 +110,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * You can set the following option(s): *
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to parse timestamps in the JSON/CSV datasources or parttion values.</li>
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
* </ul>
    * * @since 1.4.0 @@ -126,7 +126,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * You can set the following option(s): *
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to parse timestamps in the JSON/CSV datasources or parttion values.</li>
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
* </ul>
    * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 608160a214fb..deaa8006945c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -93,7 +93,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * You can set the following option(s): *
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to format timestamps in the JSON/CSV datasources or parttion values.</li>
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
* </ul>
    * * @since 1.4.0 @@ -130,7 +130,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * You can set the following option(s): *
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to format timestamps in the JSON/CSV datasources or parttion values.</li>
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
* </ul>
    * * @since 1.4.0 @@ -146,7 +146,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * You can set the following option(s): *
* <ul>
* <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to format timestamps in the JSON/CSV datasources or parttion values.</li>
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
* </ul>
    * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index aed8074a64d5..388ef182ce3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -61,6 +61,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Adds an input option for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
    + * * @since 2.0.0 */ def option(key: String, value: String): DataStreamReader = { @@ -92,6 +98,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * (Scala-specific) Adds input options for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
    + * * @since 2.0.0 */ def options(options: scala.collection.Map[String, String]): DataStreamReader = { @@ -102,6 +114,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Adds input options for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
    + * * @since 2.0.0 */ def options(options: java.util.Map[String, String]): DataStreamReader = { @@ -186,8 +204,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
* <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
- * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to parse timestamps.</li>
* <li>`wholeFile` (default `false`): parse one record, which may span multiple lines,
* per file</li>
*
@@ -239,8 +255,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
- * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to parse timestamps.</li>
* <li>`maxColumns` (default `20480`): defines a hard limit of how many columns
* a record can have.</li>
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index c8fda8cd8359..fe52013badb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -145,6 +145,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * Adds an output option for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
    + * * @since 2.0.0 */ def option(key: String, value: String): DataStreamWriter[T] = { @@ -176,6 +182,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * (Scala-specific) Adds output options for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
    + * * @since 2.0.0 */ def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { @@ -186,6 +198,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * Adds output options for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
    + * * @since 2.0.0 */ def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { From ee36bc1c9043ead3c3ba4fba7e68c6c47ad7ae7a Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 14 Mar 2017 23:57:54 -0700 Subject: [PATCH 038/512] [SPARK-19877][SQL] Restrict the nested level of a view ## What changes were proposed in this pull request? We should restrict the nested level of a view, to avoid stack overflow exception during the view resolution. ## How was this patch tested? Add new test case in `SQLViewSuite`. Author: jiangxingbo Closes #17241 from jiangxb1987/view-depth. --- .../sql/catalyst/SimpleCatalystConf.scala | 3 ++- .../sql/catalyst/analysis/Analyzer.scala | 13 +++++++--- .../apache/spark/sql/internal/SQLConf.scala | 15 +++++++++++ .../spark/sql/execution/SQLViewSuite.scala | 25 +++++++++++++++++++ 4 files changed, 51 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala index 746f84459de2..0d4903e03bf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala @@ -41,7 +41,8 @@ case class SimpleCatalystConf( override val joinReorderEnabled: Boolean = false, override val joinReorderDPThreshold: Int = 12, override val warehousePath: String = "/user/hive/warehouse", - override val sessionLocalTimeZone: String = TimeZone.getDefault().getID) + override val sessionLocalTimeZone: String = TimeZone.getDefault().getID, + override val maxNestedViewDepth: Int = 100) extends SQLConf { override def clone(): SimpleCatalystConf = this.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a3764d8c843d..68a4746a54d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -58,13 +58,12 @@ object SimpleAnalyzer extends Analyzer( * * @param defaultDatabase The default database used in the view resolution, this overrules the * current catalog database. - * @param nestedViewLevel The nested level in the view resolution, this enables us to limit the + * @param nestedViewDepth The nested depth in the view resolution, this enables us to limit the * depth of nested views. - * TODO Limit the depth of nested views. */ case class AnalysisContext( defaultDatabase: Option[String] = None, - nestedViewLevel: Int = 0) + nestedViewDepth: Int = 0) object AnalysisContext { private val value = new ThreadLocal[AnalysisContext]() { @@ -77,7 +76,7 @@ object AnalysisContext { def withAnalysisContext[A](database: Option[String])(f: => A): A = { val originContext = value.get() val context = AnalysisContext(defaultDatabase = database, - nestedViewLevel = originContext.nestedViewLevel + 1) + nestedViewDepth = originContext.nestedViewDepth + 1) set(context) try f finally { set(originContext) } } @@ -598,6 +597,12 @@ class Analyzer( case view @ View(desc, _, child) if !child.resolved => // Resolve all the UnresolvedRelations and Views in the child. 
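A hedged usage note on the nested-view limit introduced by this change: when a legitimate view chain is deeper than the default of 100, the new configuration can be raised per session instead of restructuring the views. The view name below is hypothetical.

```scala
// Illustrative values only; the config is validated to be a positive integer.
spark.conf.set("spark.sql.view.maxNestedViewDepth", "200")
spark.sql("SELECT * FROM deeply_nested_view").show()   // hypothetical deeply nested view
```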
val newChild = AnalysisContext.withAnalysisContext(desc.viewDefaultDatabase) { + if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) { + view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " + + s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " + + "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + + "aroud this.") + } execute(child) } view.copy(child = newChild) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 315bedb12e71..8f65672d5a83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -571,6 +571,19 @@ object SQLConf { .booleanConf .createWithDefault(true) + val MAX_NESTED_VIEW_DEPTH = + buildConf("spark.sql.view.maxNestedViewDepth") + .internal() + .doc("The maximum depth of a view reference in a nested view. A nested view may reference " + + "other nested views, the dependencies are organized in a directed acyclic graph (DAG). " + + "However the DAG depth may become too large and cause unexpected behavior. This " + + "configuration puts a limit on this: when the depth of a view exceeds this value during " + + "analysis, we terminate the resolution to avoid potential errors.") + .intConf + .checkValue(depth => depth > 0, "The maximum depth of a view reference in a nested view " + + "must be positive.") + .createWithDefault(100) + val STREAMING_FILE_COMMIT_PROTOCOL_CLASS = buildConf("spark.sql.streaming.commitProtocolClass") .internal() @@ -932,6 +945,8 @@ class SQLConf extends Serializable with Logging { def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 2ca2206bb9d4..d32716c18ddf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -644,4 +644,29 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { "-> `default`.`view2` -> `default`.`view1`)")) } } + + test("restrict the nested level of a view") { + val viewNames = Array.range(0, 11).map(idx => s"view$idx") + withView(viewNames: _*) { + sql("CREATE VIEW view0 AS SELECT * FROM jt") + Array.range(0, 10).foreach { idx => + sql(s"CREATE VIEW view${idx + 1} AS SELECT * FROM view$idx") + } + + withSQLConf("spark.sql.view.maxNestedViewDepth" -> "10") { + val e = intercept[AnalysisException] { + sql("SELECT * FROM view10") + }.getMessage + assert(e.contains("The depth of view `default`.`view0` exceeds the maximum view " + + "resolution depth (10). Analysis is aborted to avoid errors. 
Increase the value " + + "of spark.sql.view.maxNestedViewDepth to work aroud this.")) + } + + val e = intercept[IllegalArgumentException] { + withSQLConf("spark.sql.view.maxNestedViewDepth" -> "0") {} + }.getMessage + assert(e.contains("The maximum depth of a view reference in a nested view must be " + + "positive.")) + } + } } From 9ff85be3bd6bf3a782c0e52fa9c2598d79f310bb Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 15 Mar 2017 10:46:05 +0100 Subject: [PATCH 039/512] [SPARK-19889][SQL] Make TaskContext callbacks thread safe ## What changes were proposed in this pull request? It is sometimes useful to use multiple threads in a task to parallelize tasks. These threads might register some completion/failure listeners to clean up when the task completes or fails. We currently cannot register such a callback and be sure that it will get called, because the context might be in the process of invoking its callbacks, when the the callback gets registered. This PR improves this by making sure that you cannot add a completion/failure listener from a different thread when the context is being marked as completed/failed in another thread. This is done by synchronizing these methods on the task context itself. Failure listeners were called only once. Completion listeners now follow the same pattern; this lifts the idempotency requirement for completion listeners and makes it easier to implement them. In some cases we can (accidentally) add a completion/failure listener after the fact, these listeners will be called immediately in order make sure we can safely clean-up after a task. As a result of this change we could make the `failure` and `completed` flags non-volatile. The `isCompleted()` method now uses synchronization to ensure that updates are visible across threads. ## How was this patch tested? Adding tests to `TaskContestSuite` to test adding listeners to a completed/failed context. Author: Herman van Hovell Closes #17244 from hvanhovell/SPARK-19889. --- .../scala/org/apache/spark/TaskContext.scala | 16 ++-- .../org/apache/spark/TaskContextImpl.scala | 85 +++++++++++++------ .../spark/scheduler/TaskContextSuite.scala | 26 ++++++ 3 files changed, 93 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index f0867ecb16ea..5acfce17593b 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -105,7 +105,9 @@ abstract class TaskContext extends Serializable { /** * Adds a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. + * This will be called in all situations - success, failure, or cancellation. Adding a listener + * to an already completed task will result in that listener being called immediately. + * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. @@ -114,7 +116,9 @@ abstract class TaskContext extends Serializable { /** * Adds a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situations - success, failure, or cancellation. + * This will be called in all situations - success, failure, or cancellation. Adding a listener + * to an already completed task will result in that listener being called immediately. 
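A minimal sketch of the listener contract described above, assuming it runs inside a task (for example within `mapPartitions`); `rdd`, `openConnection`, and `process` are placeholders, not APIs from this patch.

```scala
// With this change, a listener registered after the task has already completed or failed
// is invoked immediately rather than being lost.
rdd.mapPartitions { iter =>
  val ctx = org.apache.spark.TaskContext.get()
  val conn = openConnection()                            // hypothetical per-partition resource
  ctx.addTaskCompletionListener { _ => conn.close() }    // fires on success, failure, or cancellation
  iter.map(row => process(conn, row))                    // hypothetical processing function
}
```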
+ * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. @@ -126,14 +130,14 @@ abstract class TaskContext extends Serializable { } /** - * Adds a listener to be executed on task failure. - * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + * Adds a listener to be executed on task failure. Adding a listener to an already failed task + * will result in that listener being called immediately. */ def addTaskFailureListener(listener: TaskFailureListener): TaskContext /** - * Adds a listener to be executed on task failure. - * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + * Adds a listener to be executed on task failure. Adding a listener to an already failed task + * will result in that listener being called immediately. */ def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = { addTaskFailureListener(new TaskFailureListener { diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index dc0d12878550..ea8dcdfd5d7d 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.Properties +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer @@ -29,6 +30,16 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ +/** + * A [[TaskContext]] implementation. + * + * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes + * sure that updates are always visible across threads. The complete & failed flags and their + * callbacks are protected by locking on the context instance. For instance, this ensures + * that you cannot add a completion listener in one thread while we are completing (and calling + * the completion listeners) in another thread. Other state is immutable, however the exposed + * [[TaskMetrics]] & [[MetricsSystem]] objects are not thread safe. + */ private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, @@ -52,62 +63,79 @@ private[spark] class TaskContextImpl( @volatile private var interrupted: Boolean = false // Whether the task has completed. - @volatile private var completed: Boolean = false + private var completed: Boolean = false // Whether the task has failed. - @volatile private var failed: Boolean = false + private var failed: Boolean = false + + // Throwable that caused the task to fail + private var failure: Throwable = _ // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't // hide the exception. 
See SPARK-19276 @volatile private var _fetchFailedException: Option[FetchFailedException] = None - override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { - onCompleteCallbacks += listener + @GuardedBy("this") + override def addTaskCompletionListener(listener: TaskCompletionListener) + : this.type = synchronized { + if (completed) { + listener.onTaskCompletion(this) + } else { + onCompleteCallbacks += listener + } this } - override def addTaskFailureListener(listener: TaskFailureListener): this.type = { - onFailureCallbacks += listener + @GuardedBy("this") + override def addTaskFailureListener(listener: TaskFailureListener) + : this.type = synchronized { + if (failed) { + listener.onTaskFailure(this, failure) + } else { + onFailureCallbacks += listener + } this } /** Marks the task as failed and triggers the failure listeners. */ - private[spark] def markTaskFailed(error: Throwable): Unit = { - // failure callbacks should only be called once + @GuardedBy("this") + private[spark] def markTaskFailed(error: Throwable): Unit = synchronized { if (failed) return failed = true - val errorMsgs = new ArrayBuffer[String](2) - // Process failure callbacks in the reverse order of registration - onFailureCallbacks.reverse.foreach { listener => - try { - listener.onTaskFailure(this, error) - } catch { - case e: Throwable => - errorMsgs += e.getMessage - logError("Error in TaskFailureListener", e) - } - } - if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs, Option(error)) + failure = error + invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) { + _.onTaskFailure(this, error) } } /** Marks the task as completed and triggers the completion listeners. */ - private[spark] def markTaskCompleted(): Unit = { + @GuardedBy("this") + private[spark] def markTaskCompleted(): Unit = synchronized { + if (completed) return completed = true + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) { + _.onTaskCompletion(this) + } + } + + private def invokeListeners[T]( + listeners: Seq[T], + name: String, + error: Option[Throwable])( + callback: T => Unit): Unit = { val errorMsgs = new ArrayBuffer[String](2) - // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { listener => + // Process callbacks in the reverse order of registration + listeners.reverse.foreach { listener => try { - listener.onTaskCompletion(this) + callback(listener) } catch { case e: Throwable => errorMsgs += e.getMessage - logError("Error in TaskCompletionListener", e) + logError(s"Error in $name", e) } } if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs) + throw new TaskCompletionListenerException(errorMsgs, error) } } @@ -116,7 +144,8 @@ private[spark] class TaskContextImpl( interrupted = true } - override def isCompleted(): Boolean = completed + @GuardedBy("this") + override def isCompleted(): Boolean = synchronized(completed) override def isRunningLocally(): Boolean = false diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 7004128308af..8f576daa77d1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -228,6 +228,32 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(res === Array("testPropValue,testPropValue")) } + 
test("immediately call a completion listener if the context is completed") { + var invocations = 0 + val context = TaskContext.empty() + context.markTaskCompleted() + context.addTaskCompletionListener(_ => invocations += 1) + assert(invocations == 1) + context.markTaskCompleted() + assert(invocations == 1) + } + + test("immediately call a failure listener if the context has failed") { + var invocations = 0 + var lastError: Throwable = null + val error = new RuntimeException + val context = TaskContext.empty() + context.markTaskFailed(error) + context.addTaskFailureListener { (_, e) => + lastError = e + invocations += 1 + } + assert(lastError == error) + assert(invocations == 1) + context.markTaskFailed(error) + assert(lastError == error) + assert(invocations == 1) + } } private object TaskContextSuite { From 7387126f83dc0489eb1df734bfeba705709b7861 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 15 Mar 2017 10:17:18 -0700 Subject: [PATCH 040/512] [SPARK-19872] [PYTHON] Use the correct deserializer for RDD construction for coalesce/repartition ## What changes were proposed in this pull request? This PR proposes to use the correct deserializer, `BatchedSerializer` for RDD construction for coalesce/repartition when the shuffle is enabled. Currently, it is passing `UTF8Deserializer` as is not `BatchedSerializer` from the copied one. with the file, `text.txt` below: ``` a b d e f g h i j k l ``` - Before ```python >>> sc.textFile('text.txt').repartition(1).collect() ``` ``` UTF8Deserializer(True) Traceback (most recent call last): File "", line 1, in File ".../spark/python/pyspark/rdd.py", line 811, in collect return list(_load_from_socket(port, self._jrdd_deserializer)) File ".../spark/python/pyspark/serializers.py", line 549, in load_stream yield self.loads(stream) File ".../spark/python/pyspark/serializers.py", line 544, in loads return s.decode("utf-8") if self.use_unicode else s File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/encodings/utf_8.py", line 16, in decode return codecs.utf_8_decode(input, errors, True) UnicodeDecodeError: 'utf8' codec can't decode byte 0x80 in position 0: invalid start byte ``` - After ```python >>> sc.textFile('text.txt').repartition(1).collect() ``` ``` [u'a', u'b', u'', u'd', u'e', u'f', u'g', u'h', u'i', u'j', u'k', u'l', u''] ``` ## How was this patch tested? Unit test in `python/pyspark/tests.py`. Author: hyukjinkwon Closes #17282 from HyukjinKwon/SPARK-19872. 
--- python/pyspark/rdd.py | 4 +++- python/pyspark/tests.py | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a5e6e2b05496..291c1caaaed5 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2072,10 +2072,12 @@ def coalesce(self, numPartitions, shuffle=False): batchSize = min(10, self.ctx._batchSize or 1024) ser = BatchedSerializer(PickleSerializer(), batchSize) selfCopy = self._reserialize(ser) + jrdd_deserializer = selfCopy._jrdd_deserializer jrdd = selfCopy._jrdd.coalesce(numPartitions, shuffle) else: + jrdd_deserializer = self._jrdd_deserializer jrdd = self._jrdd.coalesce(numPartitions, shuffle) - return RDD(jrdd, self.ctx, self._jrdd_deserializer) + return RDD(jrdd, self.ctx, jrdd_deserializer) def zip(self, other): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c6c87a9ea555..bb13de563cdd 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1037,6 +1037,12 @@ def test_repartition_no_skewed(self): zeros = len([x for x in l if x == 0]) self.assertTrue(zeros == 0) + def test_repartition_on_textfile(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + rdd = self.sc.textFile(path) + result = rdd.repartition(1).collect() + self.assertEqual(u"Hello World!", result[0]) + def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) self.assertEqual(rdd.getNumPartitions(), 10) From 02c274eaba0a8e7611226e0d4e93d3c36253f4ce Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Wed, 15 Mar 2017 20:18:39 +0100 Subject: [PATCH 041/512] [SPARK-13450] Introduce ExternalAppendOnlyUnsafeRowArray. Change CartesianProductExec, SortMergeJoin, WindowExec to use it ## What issue does this PR address ? Jira: https://issues.apache.org/jira/browse/SPARK-13450 In `SortMergeJoinExec`, rows of the right relation having the same value for a join key are buffered in-memory. In case of skew, this causes OOMs (see comments in SPARK-13450 for more details). Heap dump from a failed job confirms this : https://issues.apache.org/jira/secure/attachment/12846382/heap-dump-analysis.png . While its possible to increase the heap size to workaround, Spark should be resilient to such issues as skews can happen arbitrarily. ## Change proposed in this pull request - Introduces `ExternalAppendOnlyUnsafeRowArray` - It holds `UnsafeRow`s in-memory upto a certain threshold. - After the threshold is hit, it switches to `UnsafeExternalSorter` which enables spilling of the rows to disk. It does NOT sort the data. - Allows iterating the array multiple times. However, any alteration to the array (using `add` or `clear`) will invalidate the existing iterator(s) - `WindowExec` was already using `UnsafeExternalSorter` to support spilling. Changed it to use the new array - Changed `SortMergeJoinExec` to use the new array implementation - NOTE: I have not changed FULL OUTER JOIN to use this new array implementation. Changing that will need more surgery and I will rather put up a separate PR for that once this gets in. - Changed `CartesianProductExec` to use the new array implementation #### Note for reviewers The diff can be divided into 3 parts. My motive behind having all the changes in a single PR was to demonstrate that the API is sane and supports 2 use cases. If reviewing as 3 separate PRs would help, I am happy to make the split. ## How was this patch tested ? 
#### Unit testing - Added unit tests `ExternalAppendOnlyUnsafeRowArray` to validate all its APIs and access patterns - Added unit test for `SortMergeExec` - with and without spill for inner join, left outer join, right outer join to confirm that the spill threshold config behaves as expected and output is as expected. - This PR touches the scanning logic in `SortMergeExec` for _all_ joins (except FULL OUTER JOIN). However, I expect existing test cases to cover that there is no regression in correctness. - Added unit test for `WindowExec` to check behavior of spilling and correctness of results. #### Stress testing - Confirmed that OOM is gone by running against a production job which used to OOM - Since I cannot share details about prod workload externally, created synthetic data to mimic the issue. Ran before and after the fix to demonstrate the issue and query success with this PR Generating the synthetic data ``` ./bin/spark-shell --driver-memory=6G import org.apache.spark.sql._ val hc = SparkSession.builder.master("local").getOrCreate() hc.sql("DROP TABLE IF EXISTS spark_13450_large_table").collect hc.sql("DROP TABLE IF EXISTS spark_13450_one_row_table").collect val df1 = (0 until 1).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2") df1.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_one_row_table") val df2 = (0 until 3000000).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2") df2.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_large_table") ``` Ran this against trunk VS local build with this PR. OOM repros with trunk and with the fix this query runs fine. ``` ./bin/spark-shell --driver-java-options="-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/tmp/spark.driver.heapdump.hprof" import org.apache.spark.sql._ val hc = SparkSession.builder.master("local").getOrCreate() hc.sql("SET spark.sql.autoBroadcastJoinThreshold=1") hc.sql("SET spark.sql.sortMergeJoinExec.buffer.spill.threshold=10000") hc.sql("DROP TABLE IF EXISTS spark_13450_result").collect hc.sql(""" CREATE TABLE spark_13450_result AS SELECT a.i AS a_i, a.j AS a_j, a.str1 AS a_str1, a.str2 AS a_str2, b.i AS b_i, b.j AS b_j, b.str1 AS b_str1, b.str2 AS b_str2 FROM spark_13450_one_row_table a JOIN spark_13450_large_table b ON a.i=b.i AND a.j=b.j """) ``` ## Performance comparison ### Macro-benchmark I ran a SMB join query over two real world tables (2 trillion rows (40 TB) and 6 million rows (120 GB)). Note that this dataset does not have skew so no spill happened. I saw improvement in CPU time by 2-4% over version without this PR. This did not add up as I was expected some regression. I think allocating array of capacity of 128 at the start (instead of starting with default size 16) is the sole reason for the perf. gain : https://github.com/tejasapatil/spark/blob/SPARK-13450_smb_buffer_oom/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala#L43 . I could remove that and rerun, but effectively the change will be deployed in this form and I wanted to see the effect of it over large workload. 
### Micro-benchmark Two types of benchmarking can be found in `ExternalAppendOnlyUnsafeRowArrayBenchmark`: [A] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `ArrayBuffer` when all rows fit in-memory and there is no spill ``` Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 7821 / 7941 33.5 29.8 1.0X ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 19200 / 19206 25.6 39.1 1.0X ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 5949 / 6028 17.2 58.1 1.0X ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X ``` [B] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `UnsafeExternalSorter` when there is spilling of data ``` Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X ``` Author: Tejas Patil Closes #16909 from tejasapatil/SPARK-13450_smb_buffer_oom. 
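Since the spill thresholds referenced above are introduced by this patch, here is a short hedged example of overriding them per session (the values are illustrative, not recommendations):

```scala
// Defaults: 4096 for WindowExec, Int.MaxValue for SortMergeJoinExec, and the
// UnsafeExternalSorter default for CartesianProductExec. Lower values spill earlier.
spark.conf.set("spark.sql.windowExec.buffer.spill.threshold", "10000")
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.spill.threshold", "10000")
spark.conf.set("spark.sql.cartesianProductExec.buffer.spill.threshold", "10000")
```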
--- .../apache/spark/sql/internal/SQLConf.scala | 30 ++ .../ExternalAppendOnlyUnsafeRowArray.scala | 243 ++++++++++++ .../joins/CartesianProductExec.scala | 52 +-- .../execution/joins/SortMergeJoinExec.scala | 117 +++--- .../sql/execution/window/RowBuffer.scala | 115 ------ .../sql/execution/window/WindowExec.scala | 72 +--- .../window/WindowFunctionFrame.scala | 97 +++-- .../org/apache/spark/sql/JoinSuite.scala | 136 ++++++- ...nalAppendOnlyUnsafeRowArrayBenchmark.scala | 233 ++++++++++++ ...xternalAppendOnlyUnsafeRowArraySuite.scala | 351 ++++++++++++++++++ .../execution/SQLWindowFunctionSuite.scala | 33 ++ 11 files changed, 1187 insertions(+), 292 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8f65672d5a83..a85f87aece45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. 
@@ -715,6 +716,27 @@ object SQLConf { .stringConf .createWithDefault(TimeZone.getDefault().getID()) + val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in window operator") + .intConf + .createWithDefault(4096) + + val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in sort merge join operator") + .intConf + .createWithDefault(Int.MaxValue) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in cartesian product operator") + .intConf + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -945,6 +967,14 @@ class SQLConf extends Serializable with Logging { def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + + def sortMergeJoinExecBufferSpillThreshold: Int = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + + def cartesianProductExecBufferSpillThreshold: Int = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) + def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala new file mode 100644 index 000000000000..458ac4ba3637 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.ConcurrentModificationException + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} + +/** + * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined + * threshold of rows is reached. 
+ * + * Setting spill threshold faces following trade-off: + * + * - If the spill threshold is too high, the in-memory array may occupy more memory than is + * available, resulting in OOM. + * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes. + * This may lead to a performance regression compared to the normal case of using an + * [[ArrayBuffer]] or [[Array]]. + */ +private[sql] class ExternalAppendOnlyUnsafeRowArray( + taskMemoryManager: TaskMemoryManager, + blockManager: BlockManager, + serializerManager: SerializerManager, + taskContext: TaskContext, + initialSize: Int, + pageSizeBytes: Long, + numRowsSpillThreshold: Int) extends Logging { + + def this(numRowsSpillThreshold: Int) { + this( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numRowsSpillThreshold) + } + + private val initialSizeOfInMemoryBuffer = + Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold) + + private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { + new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) + } else { + null + } + + private var spillableArray: UnsafeExternalSorter = _ + private var numRows = 0 + + // A counter to keep track of total modifications done to this array since its creation. + // This helps to invalidate iterators when there are changes done to the backing array. + private var modificationsCount: Long = 0 + + private var numFieldsPerRow = 0 + + def length: Int = numRows + + def isEmpty: Boolean = numRows == 0 + + /** + * Clears up resources (eg. memory) held by the backing storage + */ + def clear(): Unit = { + if (spillableArray != null) { + // The last `spillableArray` of this task will be cleaned up via task completion listener + // inside `UnsafeExternalSorter` + spillableArray.cleanupResources() + spillableArray = null + } else if (inMemoryBuffer != null) { + inMemoryBuffer.clear() + } + numFieldsPerRow = 0 + numRows = 0 + modificationsCount += 1 + } + + def add(unsafeRow: UnsafeRow): Unit = { + if (numRows < numRowsSpillThreshold) { + inMemoryBuffer += unsafeRow.copy() + } else { + if (spillableArray == null) { + logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " + + s"${classOf[UnsafeExternalSorter].getName}") + + // We will not sort the rows, so prefixComparator and recordComparator are null + spillableArray = UnsafeExternalSorter.create( + taskMemoryManager, + blockManager, + serializerManager, + taskContext, + null, + null, + initialSize, + pageSizeBytes, + numRowsSpillThreshold, + false) + + // populate with existing in-memory buffered rows + if (inMemoryBuffer != null) { + inMemoryBuffer.foreach(existingUnsafeRow => + spillableArray.insertRecord( + existingUnsafeRow.getBaseObject, + existingUnsafeRow.getBaseOffset, + existingUnsafeRow.getSizeInBytes, + 0, + false) + ) + inMemoryBuffer.clear() + } + numFieldsPerRow = unsafeRow.numFields() + } + + spillableArray.insertRecord( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset, + unsafeRow.getSizeInBytes, + 0, + false) + } + + numRows += 1 + modificationsCount += 1 + } + + /** + * Creates an [[Iterator]] for the current rows in the array starting from a user provided index + * + * If there are subsequent [[add()]] or [[clear()]] calls made on this array after creation of + * the iterator, then the iterator is invalidated thus saving clients from thinking that they + * have read all the data while there 
were new rows added to this array. + */ + def generateIterator(startIndex: Int): Iterator[UnsafeRow] = { + if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. " + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + + if (spillableArray == null) { + new InMemoryBufferIterator(startIndex) + } else { + new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex) + } + } + + def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) + + private[this] + abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { + private val expectedModificationsCount = modificationsCount + + protected def isModified(): Boolean = expectedModificationsCount != modificationsCount + + protected def throwExceptionIfModified(): Unit = { + if (expectedModificationsCount != modificationsCount) { + throw new ConcurrentModificationException( + s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " + + s"since the creation of this Iterator") + } + } + } + + private[this] class InMemoryBufferIterator(startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private var currentIndex = startIndex + + override def hasNext(): Boolean = !isModified() && currentIndex < numRows + + override def next(): UnsafeRow = { + throwExceptionIfModified() + val result = inMemoryBuffer(currentIndex) + currentIndex += 1 + result + } + } + + private[this] class SpillableArrayIterator( + iterator: UnsafeSorterIterator, + numFieldPerRow: Int, + startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private val currentRow = new UnsafeRow(numFieldPerRow) + + def init(): Unit = { + var i = 0 + while (i < startIndex) { + if (iterator.hasNext) { + iterator.loadNext() + } else { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. 
" + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + i += 1 + } + } + + // Traverse upto the given [[startIndex]] + init() + + override def hasNext(): Boolean = !isModified() && iterator.hasNext + + override def next(): UnsafeRow = { + throwExceptionIfModified() + iterator.loadNext() + currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength) + currentRow + } + } +} + +private[sql] object ExternalAppendOnlyUnsafeRowArray { + val DefaultInitialSizeOfInMemoryBuffer = 128 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 8341fe2ffd07..f38098695131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, * will be much faster than building the right partition for every row in left RDD, it also * materialize the right RDD (in case of the right RDD is nondeterministic). */ -class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) +class UnsafeCartesianRDD( + left : RDD[UnsafeRow], + right : RDD[UnsafeRow], + numFieldsOfRight: Int, + spillThreshold: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - // We will not sort the rows, so prefixComparator and recordComparator are null. 
- val sorter = UnsafeExternalSorter.create( - context.taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - context, - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold) val partition = split.asInstanceOf[CartesianPartition] - for (y <- rdd2.iterator(partition.s2, context)) { - sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false) - } + rdd2.iterator(partition.s2, context).foreach(rowArray.add) - // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] - def createIter(): Iterator[UnsafeRow] = { - val iter = sorter.getIterator - val unsafeRow = new UnsafeRow(numFieldsOfRight) - new Iterator[UnsafeRow] { - override def hasNext: Boolean = { - iter.hasNext - } - override def next(): UnsafeRow = { - iter.loadNext() - unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - unsafeRow - } - } - } + // Create an iterator from rowArray + def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator() val resultIter = for (x <- rdd1.iterator(partition.s1, context); y <- createIter()) yield (x, y) CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( - resultIter, sorter.cleanupResources()) + resultIter, rowArray.clear()) } } @@ -97,7 +71,9 @@ case class CartesianProductExec( val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold + + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ca9c0ed8cec3..bcdc4dcdf7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, +ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -95,9 +96,13 @@ case class SortMergeJoinExec( private def createRightKeyGenerator(): Projection = UnsafeProjection.create(rightKeys, right.output) + private def getSpillThreshold: Int = { + sqlContext.conf.sortMergeJoinExecBufferSpillThreshold + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - + val spillThreshold = getSpillThreshold 
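Evaluating the threshold into a local `val` before `zipPartitions` means the task closure below captures only the resulting `Int`, not driver-side session state such as `sqlContext`. A minimal, self-contained sketch of that capture-before-closure pattern follows; the configuration key is taken from the tests later in this patch, while the default value and the surrounding names are illustrative only:

```scala
import org.apache.spark.sql.SparkSession

// Hedged sketch: read the spill threshold once on the driver, then close over the Int only.
val spark = SparkSession.builder().master("local").appName("capture-sketch").getOrCreate()
val spillThreshold =
  spark.conf.get("spark.sql.sortMergeJoinExec.buffer.spill.threshold", "4096").toInt // default is illustrative

val left = spark.sparkContext.parallelize(1 to 10, 2)
val right = spark.sparkContext.parallelize(11 to 20, 2)
val zipped = left.zipPartitions(right) { (l, r) =>
  // Runs on executors; only `spillThreshold` (a plain Int) was serialized with this closure.
  l.zip(r).take(spillThreshold)
}
zipped.collect()
```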
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => @@ -115,39 +120,39 @@ case class SortMergeJoinExec( case _: InnerLike => new RowIterator { private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ - private[this] var currentMatchIdx: Int = -1 + private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + rightMatchesIterator = currentRightMatches.generateIterator() } override def advanceNext(): Boolean = { - while (currentMatchIdx >= 0) { - if (currentMatchIdx == currentRightMatches.length) { + while (rightMatchesIterator != null) { + if (!rightMatchesIterator.hasNext) { if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + rightMatchesIterator = currentRightMatches.generateIterator() } else { currentRightMatches = null currentLeftRow = null - currentMatchIdx = -1 + rightMatchesIterator = null return false } } - joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) - currentMatchIdx += 1 + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { numOutputRows += 1 return true @@ -165,7 +170,8 @@ case class SortMergeJoinExec( bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter) + bufferedIter = RowIterator.fromScala(rightIter), + spillThreshold ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( @@ -177,7 +183,8 @@ case class SortMergeJoinExec( bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter) + bufferedIter = RowIterator.fromScala(leftIter), + spillThreshold ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( @@ -209,7 +216,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -217,14 +225,15 @@ case class SortMergeJoinExec( while (smjScanner.findNextInnerJoinRows()) { val currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - var i = 0 - while (i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) - if (boundCondition(joinRow)) { - numOutputRows += 1 - return true + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } } - i += 1 } } false @@ -241,7 +250,8 @@ case class 
SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -249,17 +259,16 @@ case class SortMergeJoinExec( while (smjScanner.findNextOuterJoinRows()) { currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches - if (currentRightMatches == null) { + if (currentRightMatches == null || currentRightMatches.length == 0) { return true } - var i = 0 var found = false - while (!found && i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { found = true } - i += 1 } if (!found) { numOutputRows += 1 @@ -281,7 +290,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -290,14 +300,13 @@ case class SortMergeJoinExec( currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches var found = false - if (currentRightMatches != null) { - var i = 0 - while (!found && i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { found = true } - i += 1 } } result.setBoolean(0, found) @@ -376,8 +385,11 @@ case class SortMergeJoinExec( // A list to hold all matched rows from right side. val matches = ctx.freshName("matches") - val clsName = classOf[java.util.ArrayList[InternalRow]].getName - ctx.addMutableState(clsName, matches, s"$matches = new $clsName();") + val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName + + val spillThreshold = getSpillThreshold + + ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);") // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -428,7 +440,7 @@ case class SortMergeJoinExec( | } | $leftRow = null; | } else { - | $matches.add($rightRow.copy()); + | $matches.add((UnsafeRow) $rightRow); | $rightRow = null;; | } | } while ($leftRow != null); @@ -517,8 +529,7 @@ case class SortMergeJoinExec( val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) - val size = ctx.freshName("size") - val i = ctx.freshName("i") + val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. 
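The hunk below rewrites the generated inner-join loop from indexed `ArrayList` access to the iterator API of `ExternalAppendOnlyUnsafeRowArray`. As a reading aid, here is a hand-written sketch of the equivalent non-codegen pattern, assuming access to the `private[sql]` class; `boundCondition` and the method name are illustrative placeholders, not part of the patch:

```scala
import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeRow}
import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray

// Sketch: replay the buffered right-side matches for one streamed left row.
// generateIterator() transparently reads from memory or from spilled data.
def joinOneStreamedRow(
    leftRow: UnsafeRow,
    matches: ExternalAppendOnlyUnsafeRowArray,
    boundCondition: JoinedRow => Boolean): Iterator[JoinedRow] = {
  val joined = new JoinedRow
  matches.generateIterator()
    .map(rightRow => joined(leftRow, rightRow))
    .filter(boundCondition)
}
```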
@@ -551,10 +562,10 @@ case class SortMergeJoinExec( s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { - | int $size = $matches.size(); | ${beforeLoop.trim} - | for (int $i = 0; $i < $size; $i ++) { - | InternalRow $rightRow = (InternalRow) $matches.get($i); + | scala.collection.Iterator $iterator = $matches.generateIterator(); + | while ($iterator.hasNext()) { + | InternalRow $rightRow = (InternalRow) $iterator.next(); | ${condCheck.trim} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)} @@ -589,7 +600,8 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - bufferedIter: RowIterator) { + bufferedIter: RowIterator, + bufferThreshold: Int) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -600,7 +612,7 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() @@ -609,7 +621,7 @@ private[joins] class SortMergeJoinScanner( def getStreamedRow: InternalRow = streamedRow - def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. @@ -755,7 +767,7 @@ private[joins] class SortMergeJoinScanner( matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() do { - bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) advancedBufferedToRowWithNullFreeJoinKey() } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } @@ -819,7 +831,7 @@ private abstract class OneSideOuterIterator( protected[this] val joinedRow: JoinedRow = new JoinedRow() // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row - private[this] var bufferIndex: Int = 0 + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null // This iterator is initialized lazily so there should be no matches initially assert(smjScanner.getBufferedMatches.length == 0) @@ -833,7 +845,7 @@ private abstract class OneSideOuterIterator( * @return whether there are more rows in the stream to consume. 
*/ private def advanceStream(): Boolean = { - bufferIndex = 0 + rightMatchesIterator = null if (smjScanner.findNextOuterJoinRows()) { setStreamSideOutput(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { @@ -858,10 +870,13 @@ private abstract class OneSideOuterIterator( */ private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { - setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + if (rightMatchesIterator == null) { + rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator() + } + + while (!foundMatch && rightMatchesIterator.hasNext) { + setBufferedSideOutput(rightMatchesIterator.next()) foundMatch = boundCondition(joinedRow) - bufferIndex += 1 } foundMatch } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala deleted file mode 100644 index ee36c8425151..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.window - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} - - -/** - * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the - * row buffer is used to materialize a partition of rows since we need to repeatedly scan these - * rows in window function processing. - */ -private[window] abstract class RowBuffer { - - /** Number of rows. */ - def size: Int - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer -} - -/** - * A row buffer based on ArrayBuffer (the number of rows is limited). - */ -private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { - - private[this] var cursor: Int = -1 - - /** Number of rows. */ - override def size: Int = buffer.length - - /** Return next row in the buffer, null if no more left. */ - override def next(): InternalRow = { - cursor += 1 - if (cursor < buffer.length) { - buffer(cursor) - } else { - null - } - } - - /** Skip the next `n` rows. */ - override def skip(n: Int): Unit = { - cursor += n - } - - /** Return a new RowBuffer that has the same rows. 
*/ - override def copy(): RowBuffer = { - new ArrayRowBuffer(buffer) - } -} - -/** - * An external buffer of rows based on UnsafeExternalSorter. - */ -private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) - extends RowBuffer { - - private[this] val iter: UnsafeSorterIterator = sorter.getIterator - - private[this] val currentRow = new UnsafeRow(numFields) - - /** Number of rows. */ - override def size: Int = iter.getNumRecords() - - /** Return next row in the buffer, null if no more left. */ - override def next(): InternalRow = { - if (iter.hasNext) { - iter.loadNext() - currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - currentRow - } else { - null - } - } - - /** Skip the next `n` rows. */ - override def skip(n: Int): Unit = { - var i = 0 - while (i < n && iter.hasNext) { - iter.loadNext() - i += 1 - } - } - - /** Return a new RowBuffer that has the same rows. */ - override def copy(): RowBuffer = { - new ExternalRowBuffer(sorter, numFields) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 80b87d5ffa79..950a6794a74a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.IntegerType -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -284,6 +282,7 @@ case class WindowExec( // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold // Start processing. child.execute().mapPartitions { stream => @@ -310,10 +309,12 @@ case class WindowExec( fetchNextRow() // Manage the current partition. 
- val rows = ArrayBuffer.empty[UnsafeRow] val inputFields = child.output.length - var sorter: UnsafeExternalSorter = null - var rowBuffer: RowBuffer = null + + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ + val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length @@ -323,78 +324,43 @@ case class WindowExec( val currentGroup = nextGroup.copy() // clear last partition - if (sorter != null) { - // the last sorter of this task will be cleaned up via task completion listener - sorter.cleanupResources() - sorter = null - } else { - rows.clear() - } + buffer.clear() while (nextRowAvailable && nextGroup == currentGroup) { - if (sorter == null) { - rows += nextRow.copy() - - if (rows.length >= 4096) { - // We will not sort the rows, so prefixComparator and recordComparator are null. - sorter = UnsafeExternalSorter.create( - TaskContext.get().taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get(), - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) - rows.foreach { r => - sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) - } - rows.clear() - } - } else { - sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, - nextRow.getSizeInBytes, 0, false) - } + buffer.add(nextRow) fetchNextRow() } - if (sorter != null) { - rowBuffer = new ExternalRowBuffer(sorter, inputFields) - } else { - rowBuffer = new ArrayRowBuffer(rows) - } // Setup the frames. var i = 0 while (i < numFrames) { - frames(i).prepare(rowBuffer.copy()) + frames(i).prepare(buffer) i += 1 } // Setup iteration rowIndex = 0 - rowsSize = rowBuffer.size + bufferIterator = buffer.generateIterator() } // Iteration var rowIndex = 0 - var rowsSize = 0L - override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable + override final def hasNext: Boolean = + (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable val join = new JoinedRow override final def next(): InternalRow = { // Load the next partition if we need to. - if (rowIndex >= rowsSize && nextRowAvailable) { + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { fetchNextPartition() } - if (rowIndex < rowsSize) { + if (bufferIterator.hasNext) { + val current = bufferIterator.next() + // Get the results for the window frames. var i = 0 - val current = rowBuffer.next() while (i < numFrames) { frames(i).write(rowIndex, current) i += 1 @@ -406,7 +372,9 @@ case class WindowExec( // Return the projection. 
result(join) - } else throw new NoSuchElementException + } else { + throw new NoSuchElementException + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 70efc0f78ddb..af2b4fb92062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -22,6 +22,7 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray /** @@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame { * * @param rows to calculate the frame results for. */ - def prepare(rows: RowBuffer): Unit + def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit /** * Write the current results to the target row. @@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame { def write(index: Int, current: InternalRow): Unit } +object WindowFunctionFrame { + def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = { + if (iterator.hasNext) iterator.next() else null + } +} + /** * The offset window frame calculates frames containing LEAD/LAG statements. * @@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** Index of the input row currently used for output. */ private[this] var inputIndex = 0 @@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame( newMutableProjection(boundExpressions, Nil).target(target) } - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows + inputIterator = input.generateIterator() // drain the first few rows if offset is larger than zero inputIndex = 0 while (inputIndex < offset) { - input.next() + if (inputIterator.hasNext) inputIterator.next() inputIndex += 1 } inputIndex = offset } override def write(index: Int, current: InternalRow): Unit = { - if (inputIndex >= 0 && inputIndex < input.size) { - val r = input.next() + if (inputIndex >= 0 && inputIndex < input.length) { + val r = WindowFunctionFrame.getNextOrNull(inputIterator) projection(r) } else { // Use default values since the offset row does not exist. @@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** The next row from `input`. */ private[this] var nextRow: InternalRow = null @@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame( private[this] var inputLowIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. 
*/ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows - nextRow = rows.next() + inputIterator = input.generateIterator() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputHighIndex = 0 inputLowIndex = 0 buffer.clear() @@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame( // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { buffer.add(nextRow.copy()) - nextRow = input.next() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputHighIndex += 1 bufferUpdated = true } @@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - processor.initialize(input.size) + processor.initialize(input.length) val iter = buffer.iterator() while (iter.hasNext) { processor.update(iter.next()) @@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame( extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: RowBuffer): Unit = { - val size = rows.size - processor.initialize(size) - var i = 0 - while (i < size) { - processor.update(rows.next()) - i += 1 + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { + processor.initialize(rows.length) + + val iterator = rows.generateIterator() + while (iterator.hasNext) { + processor.update(iterator.next()) } } @@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** The next row from `input`. */ private[this] var nextRow: InternalRow = null @@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows - nextRow = rows.next() inputIndex = 0 - processor.initialize(input.size) + inputIterator = input.generateIterator() + if (inputIterator.hasNext) { + nextRow = inputIterator.next() + } + + processor.initialize(input.length) } /** Write the frame columns for the current row to the given target row. */ @@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { processor.update(nextRow) - nextRow = input.next() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputIndex += 1 bufferUpdated = true } @@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. 
*/ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null /** * Index of the first input row with a value equal to or greater than the lower bound of the @@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows inputIndex = 0 } @@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 - // Duplicate the input to have a new iterator - val tmp = input.copy() - - // Drop all rows from the buffer for which the input row value is smaller than + // Ignore all the rows from the buffer for which the input row value is smaller than // the output row lower bound. - tmp.skip(inputIndex) - var nextRow = tmp.next() + val iterator = input.generateIterator(startIndex = inputIndex) + + var nextRow = WindowFunctionFrame.getNextOrNull(iterator) while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { - nextRow = tmp.next() inputIndex += 1 bufferUpdated = true + nextRow = WindowFunctionFrame.getNextOrNull(iterator) } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - processor.initialize(input.size) - while (nextRow != null) { + processor.initialize(input.length) + if (nextRow != null) { processor.update(nextRow) - nextRow = tmp.next() + } + while (iterator.hasNext) { + processor.update(iterator.next()) } processor.evaluate(target) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2e006735d123..1a66aa85f5a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.mutable.ListBuffer import scala.language.existentials import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext - +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} class JoinSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -604,4 +605,137 @@ class JoinSuite extends QueryTest with SharedSQLContext { cartesianQueries.foreach(checkCartesianDetection) } + + test("test SortMergeJoin (without spill)") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) { + + assertNotSpilled(sparkContext, "inner join") { + checkAnswer( + sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"), + Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil + ) + } + + val expected = new ListBuffer[Row]() + expected.append( + Row(1, "1", 1, 1), Row(1, "1", 1, 2), + Row(2, "2", 2, 1), Row(2, "2", 2, 2), + Row(3, "3", 3, 1), Row(3, "3", 3, 2) + ) + for (i <- 4 to 100) { + expected.append(Row(i, i.toString, null, null)) + } + + assertNotSpilled(sparkContext, "left outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData big + |LEFT OUTER 
JOIN + | testData2 small + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + assertNotSpilled(sparkContext, "right outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |RIGHT OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + } + } + + test("test SortMergeJoin (with spill)") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") { + + assertSpilled(sparkContext, "inner join") { + checkAnswer( + sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"), + Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil + ) + } + + val expected = new ListBuffer[Row]() + expected.append( + Row(1, "1", 1, 1), Row(1, "1", 1, 2), + Row(2, "2", 2, 1), Row(2, "2", 2, 2), + Row(3, "3", 3, 1), Row(3, "3", 3, 2) + ) + for (i <- 4 to 100) { + expected.append(Row(i, i.toString, null, null)) + } + + assertSpilled(sparkContext, "left outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData big + |LEFT OUTER JOIN + | testData2 small + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + assertSpilled(sparkContext, "right outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |RIGHT OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]] + // so should not cause any spill + assertNotSpilled(sparkContext, "full outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |FULL OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala new file mode 100644 index 000000000000..00c5f2550cbb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +object ExternalAppendOnlyUnsafeRowArrayBenchmark { + + def testAgainstRawArrayBuffer(numSpillThreshold: Int, numRows: Int, iterations: Int): Unit = { + val random = new java.util.Random() + val rows = (1 to numRows).map(_ => { + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, random.nextLong()) + row + }) + + val benchmark = new Benchmark(s"Array with $numRows rows", iterations * numRows) + + // Internally, `ExternalAppendOnlyUnsafeRowArray` will create an + // in-memory buffer of size `numSpillThreshold`. This will mimic that + val initialSize = + Math.min( + ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + numSpillThreshold) + + benchmark.addCase("ArrayBuffer") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ArrayBuffer[UnsafeRow](initialSize) + + // Internally, `ExternalAppendOnlyUnsafeRowArray` will create a + // copy of the row. This will mimic that + rows.foreach(x => array += x.copy()) + + var i = 0 + val n = array.length + while (i < n) { + sum = sum + array(i).getLong(0) + i += 1 + } + array.clear() + } + } + + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + rows.foreach(x => array.add(x)) + + val iterator = array.generateIterator() + while (iterator.hasNext) { + sum = sum + iterator.next().getLong(0) + } + array.clear() + } + } + + val conf = new SparkConf(false) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + + val sc = new SparkContext("local", "test", conf) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + benchmark.run() + sc.stop() + } + + def testAgainstRawUnsafeExternalSorter( + numSpillThreshold: Int, + numRows: Int, + iterations: Int): Unit = { + + val random = new java.util.Random() + val rows = (1 to numRows).map(_ => { + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, random.nextLong()) + row + }) + + val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows) + + benchmark.addCase("UnsafeExternalSorter") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numSpillThreshold, + false) + + rows.foreach(x => + array.insertRecord( + x.getBaseObject, + x.getBaseOffset, + x.getSizeInBytes, + 0, + false)) + + val unsafeRow = new UnsafeRow(1) + val iter = array.getIterator + while (iter.hasNext) { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + sum = sum + unsafeRow.getLong(0) + } + 
array.cleanupResources() + } + } + + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + rows.foreach(x => array.add(x)) + + val iterator = array.generateIterator() + while (iterator.hasNext) { + sum = sum + iterator.next().getLong(0) + } + array.clear() + } + } + + val conf = new SparkConf(false) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + + val sc = new SparkContext("local", "test", conf) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + benchmark.run() + sc.stop() + } + + def main(args: Array[String]): Unit = { + + // ========================================================================================= // + // WITHOUT SPILL + // ========================================================================================= // + + val spillThreshold = 100 * 1000 + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 7821 / 7941 33.5 29.8 1.0X + ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X + */ + testAgainstRawArrayBuffer(spillThreshold, 1000, 1 << 18) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 19200 / 19206 25.6 39.1 1.0X + ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X + */ + testAgainstRawArrayBuffer(spillThreshold, 30 * 1000, 1 << 14) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 5949 / 6028 17.2 58.1 1.0X + ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X + */ + testAgainstRawArrayBuffer(spillThreshold, 100 * 1000, 1 << 10) + + // ========================================================================================= // + // WITH SPILL + // ========================================================================================= // + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X + ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X + */ + testAgainstRawUnsafeExternalSorter(100 * 1000, 1000, 1 << 18) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X + ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X + */ + testAgainstRawUnsafeExternalSorter( + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4) + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala new file mode 100644 index 000000000000..53c41639942b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.ConcurrentModificationException + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext { + private val random = new java.util.Random() + private var taskContext: TaskContext = _ + + override def afterAll(): Unit = TaskContext.unset() + + private def withExternalArray(spillThreshold: Int) + (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { + sc = new SparkContext("local", "test", new SparkConf(false)) + + taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + + val array = new ExternalAppendOnlyUnsafeRowArray( + taskContext.taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + taskContext, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + spillThreshold) + try f(array) finally { + array.clear() + } + } + + private def insertRow(array: ExternalAppendOnlyUnsafeRowArray): Long = { + val valueInserted = random.nextLong() + + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, valueInserted) + array.add(row) + valueInserted + } + + private def checkIfValueExists(iterator: Iterator[UnsafeRow], expectedValue: Long): Unit = { + assert(iterator.hasNext) + val actualRow = iterator.next() + assert(actualRow.getLong(0) == expectedValue) + assert(actualRow.getSizeInBytes == 16) + } + + private def validateData( + array: ExternalAppendOnlyUnsafeRowArray, + expectedValues: ArrayBuffer[Long]): Iterator[UnsafeRow] = { + val iterator = array.generateIterator() + for (value <- expectedValues) { + checkIfValueExists(iterator, value) + } + + assert(!iterator.hasNext) + iterator + } + + private def populateRows( + array: ExternalAppendOnlyUnsafeRowArray, + numRowsToBePopulated: Int): ArrayBuffer[Long] = { + val populatedValues = new ArrayBuffer[Long] + populateRows(array, numRowsToBePopulated, populatedValues) + } + + private def populateRows( + array: ExternalAppendOnlyUnsafeRowArray, + numRowsToBePopulated: Int, + populatedValues: ArrayBuffer[Long]): ArrayBuffer[Long] = { + for (_ <- 0 until numRowsToBePopulated) { + 
populatedValues.append(insertRow(array)) + } + populatedValues + } + + private def getNumBytesSpilled: Long = { + TaskContext.get().taskMetrics().memoryBytesSpilled + } + + private def assertNoSpill(): Unit = { + assert(getNumBytesSpilled == 0) + } + + private def assertSpill(): Unit = { + assert(getNumBytesSpilled > 0) + } + + test("insert rows less than the spillThreshold") { + val spillThreshold = 100 + withExternalArray(spillThreshold) { array => + assert(array.isEmpty) + + val expectedValues = populateRows(array, 1) + assert(!array.isEmpty) + assert(array.length == 1) + + val iterator1 = validateData(array, expectedValues) + + // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) + // Verify that NO spill has happened + populateRows(array, spillThreshold - 1, expectedValues) + assert(array.length == spillThreshold) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + + assert(!iterator1.hasNext) + assert(!iterator2.hasNext) + } + } + + test("insert rows more than the spillThreshold to force spill") { + val spillThreshold = 100 + withExternalArray(spillThreshold) { array => + val numValuesInserted = 20 * spillThreshold + + assert(array.isEmpty) + val expectedValues = populateRows(array, 1) + assert(array.length == 1) + + val iterator1 = validateData(array, expectedValues) + + // Populate more rows to trigger spill. Verify that spill has happened + populateRows(array, numValuesInserted - 1, expectedValues) + assert(array.length == numValuesInserted) + assertSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + } + + test("iterator on an empty array should be empty") { + withExternalArray(spillThreshold = 10) { array => + val iterator = array.generateIterator() + assert(array.isEmpty) + assert(array.length == 0) + assert(!iterator.hasNext) + } + } + + test("generate iterator with negative start index") { + withExternalArray(spillThreshold = 2) { array => + val exception = + intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) + + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array") + ) + } + } + + test("generate iterator with start index exceeding array's size (without spill)") { + val spillThreshold = 2 + withExternalArray(spillThreshold) { array => + populateRows(array, spillThreshold / 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + } + + test("generate iterator with start index exceeding array's size (with spill)") { + val spillThreshold = 2 + withExternalArray(spillThreshold) { array => + populateRows(array, spillThreshold * 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + } + + test("generate iterator with custom start index (without spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + val expectedValues = populateRows(array, spillThreshold) + val startIndex = spillThreshold / 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until 
expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + } + + test("generate iterator with custom start index (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + val expectedValues = populateRows(array, spillThreshold * 10) + val startIndex = spillThreshold * 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + } + + test("test iterator invalidation (without spill)") { + withExternalArray(spillThreshold = 10) { array => + // insert 2 rows, iterate until the first row + populateRows(array, 2) + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("test iterator invalidation (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate enough rows so that spill has happens + populateRows(array, spillThreshold * 2) + assertSpill() + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("clear on an empty the array") { + withExternalArray(spillThreshold = 2) { array => + val iterator = array.generateIterator() + assert(!iterator.hasNext) + + // multiple clear'ing should not have an side-effect + array.clear() + array.clear() + array.clear() + assert(array.isEmpty) + assert(array.length == 0) + + // Clearing an empty array should also invalidate any old iterators + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("clear array (without spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate rows ... but not enough to trigger spill + populateRows(array, spillThreshold / 2) + assertNoSpill() + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate few rows so that there is no spill + // Verify the data. Verify that there was no spill + val expectedValues = populateRows(array, spillThreshold / 3) + validateData(array, expectedValues) + assertNoSpill() + + // Populate more rows .. enough to not trigger a spill. + // Verify the data. 
Verify that there was no spill + populateRows(array, spillThreshold / 3, expectedValues) + validateData(array, expectedValues) + assertNoSpill() + } + } + + test("clear array (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate enough rows to trigger spill + populateRows(array, spillThreshold * 2) + val bytesSpilled = getNumBytesSpilled + assert(bytesSpilled > 0) + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate the array ... but NOT upto the point that there is spill. + // Verify data. Verify that there was NO "extra" spill + val expectedValues = populateRows(array, spillThreshold / 2) + validateData(array, expectedValues) + assert(getNumBytesSpilled == bytesSpilled) + + // Populate more rows to trigger spill + // Verify the data. Verify that there was "extra" spill + populateRows(array, spillThreshold * 2, expectedValues) + validateData(array, expectedValues) + assert(getNumBytesSpilled > bytesSpilled) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index afd47897ed4b..52e4f047225d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.TestUtils.assertSpilled case class WindowData(month: Int, area: String, product: Int) @@ -412,4 +413,36 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(1, 3, null) :: Row(2, null, 4) :: Nil) } + + test("test with low buffer spill threshold") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.createOrReplaceTempView("nums") + + val expected = + Row(1, 1, 1) :: + Row(0, 2, 3) :: + Row(1, 3, 6) :: + Row(0, 4, 10) :: + Row(1, 5, 15) :: + Row(0, 6, 21) :: + Row(1, 7, 28) :: + Row(0, 8, 36) :: + Row(1, 9, 45) :: + Row(0, 10, 55) :: Nil + + val actual = sql( + """ + |SELECT y, x, sum(x) OVER w1 AS running_sum + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) + """.stripMargin) + + withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") { + assertSpilled(sparkContext, "test with low buffer spill threshold") { + checkAnswer(actual, expected) + } + } + + spark.catalog.dropTempView("nums") + } } From 97cc5e5a5555519d221d0ca78645dde9bb8ea40b Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Wed, 15 Mar 2017 14:58:19 -0700 Subject: [PATCH 042/512] [SPARK-19960][CORE] Move `SparkHadoopWriter` to `internal/io/` ## What changes were proposed in this pull request? This PR introduces the following changes: 1. Move `SparkHadoopWriter` to `core/internal/io/`, so that it's in the same directory with `SparkHadoopMapReduceWriter`; 2. Move `SparkHadoopWriterUtils` to a separated file. After this PR is merged, we may consolidate `SparkHadoopWriter` and `SparkHadoopMapReduceWriter`, and make the new commit protocol support the old `mapred` package's committer; ## How was this patch tested? Tested by existing test cases. Author: jiangxingbo Closes #17304 from jiangxb1987/writer. 
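For reference, a minimal sketch (illustrative only, not part of this patch) of how Spark-internal callers such as `PairRDDFunctions` use the relocated `SparkHadoopWriterUtils` helpers from their new home in `org.apache.spark.internal.io`; the output path below is made up, and the object remains `private[spark]`, so this only compiles inside Spark's own packages:

```scala
// Illustrative sketch only (not part of this patch): calling the relocated helpers
// after the move to core/internal/io/. The "/tmp/output" path is hypothetical.
import java.util.Date

import org.apache.hadoop.mapred.JobConf

import org.apache.spark.internal.io.SparkHadoopWriterUtils

val jobConf = new JobConf()
// Build a Hadoop JobID from a timestamp plus a numeric id (0 here, chosen arbitrarily).
val jobId = SparkHadoopWriterUtils.createJobID(new Date(), 0)
// Qualify a plain output path string against its FileSystem; throws if the path is null.
val outputPath = SparkHadoopWriterUtils.createPathFromString("/tmp/output", jobConf)
```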
--- .../io/SparkHadoopMapReduceWriter.scala | 59 ------------ .../{ => internal/io}/SparkHadoopWriter.scala | 7 +- .../internal/io/SparkHadoopWriterUtils.scala | 93 +++++++++++++++++++ .../apache/spark/rdd/PairRDDFunctions.scala | 3 +- .../OutputCommitCoordinatorSuite.scala | 1 + 5 files changed, 99 insertions(+), 64 deletions(-) rename core/src/main/scala/org/apache/spark/{ => internal/io}/SparkHadoopWriter.scala (97%) create mode 100644 core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala index 659ad5d0bad8..376ff9bb19f7 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -179,62 +179,3 @@ object SparkHadoopMapReduceWriter extends Logging { } } } - -private[spark] -object SparkHadoopWriterUtils { - - private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 - - def createJobID(time: Date, id: Int): JobID = { - val jobtrackerID = createJobTrackerID(time) - new JobID(jobtrackerID, id) - } - - def createJobTrackerID(time: Date): String = { - new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) - } - - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - - // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation - // setting can take effect: - def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { - val validationDisabled = disableOutputSpecValidation.value - val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) - enabledInConf && !validationDisabled - } - - // TODO: these don't seem like the right abstractions. - // We should abstract the duplicate code in a less awkward way. - - def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, () => Long) = { - val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() - (context.taskMetrics().outputMetrics, bytesWrittenCallback) - } - - def maybeUpdateOutputMetrics( - outputMetrics: OutputMetrics, - callback: () => Long, - recordsWritten: Long): Unit = { - if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - } - } - - /** - * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case - * basis; see SPARK-4835 for more details. 
- */ - val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) -} diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala rename to core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 46e22b215b8e..acc9c3857100 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.internal.io import java.io.IOException -import java.text.NumberFormat -import java.text.SimpleDateFormat +import java.text.{NumberFormat, SimpleDateFormat} import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType +import org.apache.spark.SerializableWritable import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.SparkHadoopWriterUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD import org.apache.spark.util.SerializableJobConf diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala new file mode 100644 index 000000000000..de828a6d6156 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.internal.io + +import java.text.SimpleDateFormat +import java.util.{Date, Locale} + +import scala.util.DynamicVariable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.{JobConf, JobID} + +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.OutputMetrics + +/** + * A helper object that provide common utils used during saving an RDD using a Hadoop OutputFormat + * (both from the old mapred API and the new mapreduce API) + */ +private[spark] +object SparkHadoopWriterUtils { + + private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 + + def createJobID(time: Date, id: Int): JobID = { + val jobtrackerID = createJobTrackerID(time) + new JobID(jobtrackerID, id) + } + + def createJobTrackerID(time: Date): String = { + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) + } + + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation + // setting can take effect: + def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { + val validationDisabled = disableOutputSpecValidation.value + val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) + enabledInConf && !validationDisabled + } + + // TODO: these don't seem like the right abstractions. + // We should abstract the duplicate code in a less awkward way. + + def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, () => Long) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() + (context.taskMetrics().outputMetrics, bytesWrittenCallback) + } + + def maybeUpdateOutputMetrics( + outputMetrics: OutputMetrics, + callback: () => Long, + recordsWritten: Long): Unit = { + if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + } + } + + /** + * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case + * basis; see SPARK-4835 for more details. 
+ */ + val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 52ce03ff8cde..58762cc0838c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -37,7 +37,8 @@ import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriterUtils} +import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter, + SparkHadoopWriterUtils} import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 83ed12752074..38b9d40329d4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -31,6 +31,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.internal.io.SparkHadoopWriter import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.{ThreadUtils, Utils} From 54a3697f1fb562ef9ed8fed9caffc62b84763049 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 15 Mar 2017 15:01:16 -0700 Subject: [PATCH 043/512] [MINOR][CORE] Fix a info message of `prunePartitions` ## What changes were proposed in this pull request? `PrunedInMemoryFileIndex.prunePartitions` shows `pruned NaN% partitions` for the following case. ```scala scala> Seq.empty[(String, String)].toDF("a", "p").write.partitionBy("p").saveAsTable("t1") scala> sc.setLogLevel("INFO") scala> spark.table("t1").filter($"p" === "1").select($"a").show ... 17/03/13 00:33:04 INFO PrunedInMemoryFileIndex: Selected 0 partitions out of 0, pruned NaN% partitions. ``` After this PR, the message looks like this. ```scala 17/03/15 10:39:48 INFO PrunedInMemoryFileIndex: Selected 0 partitions out of 0, pruned 0 partitions. ``` ## How was this patch tested? Pass the Jenkins with the existing tests. Author: Dongjoon Hyun Closes #17273 from dongjoon-hyun/SPARK-EMPTY-PARTITION. --- .../sql/execution/datasources/PartitioningAwareFileIndex.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index a5fa8b3f9385..db8bbc52aaf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -186,7 +186,8 @@ abstract class PartitioningAwareFileIndex( val total = partitions.length val selectedSize = selected.length val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 - s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." 
+ s"Selected $selectedSize partitions out of $total, " + + s"pruned ${if (total == 0) "0" else s"$percentPruned%"} partitions." } selected From 046b8d4aef00b0701cf7e4b99aeaf450cacb42fe Mon Sep 17 00:00:00 2001 From: erenavsarogullari Date: Wed, 15 Mar 2017 15:57:51 -0700 Subject: [PATCH 044/512] [SPARK-18066][CORE][TESTS] Add Pool usage policies test coverage for FIFO & FAIR Schedulers ## What changes were proposed in this pull request? The following FIFO & FAIR Schedulers Pool usage cases need to have unit test coverage : - FIFO Scheduler just uses **root pool** so even if `spark.scheduler.pool` property is set, related pool is not created and `TaskSetManagers` are added to **root pool**. - FAIR Scheduler uses `default pool` when `spark.scheduler.pool` property is not set. This can be happened when - `Properties` object is **null**, - `Properties` object is **empty**(`new Properties()`), - **default pool** is set(`spark.scheduler.pool=default`). - FAIR Scheduler creates a **new pool** with **default values** when `spark.scheduler.pool` property points a **non-existent** pool. This can be happened when **scheduler allocation file** is not set or it does not contain related pool. ## How was this patch tested? New Unit tests are added. Author: erenavsarogullari Closes #15604 from erenavsarogullari/SPARK-18066. --- .../spark/scheduler/SchedulableBuilder.scala | 7 +- .../apache/spark/scheduler/PoolSuite.scala | 97 +++++++++++++++++-- 2 files changed, 96 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index e53c4fb5b477..20cedaf06042 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -191,8 +191,11 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) rootPool.addSchedulable(parentPool) - logInfo("Created pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " + + "configured. This can happen when the file that pools are read from isn't set, or " + + s"when that file doesn't contain $poolName. 
Created $poolName with default " + + s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " + + s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)") } } parentPool.addSchedulable(manager) diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index 520736ab6427..cddff3dd3586 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -31,6 +31,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { val LOCAL = "local" val APP_NAME = "PoolSuite" val SCHEDULER_ALLOCATION_FILE_PROPERTY = "spark.scheduler.allocation.file" + val TEST_POOL = "testPool" def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl) : TaskSetManager = { @@ -40,7 +41,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { new TaskSetManager(taskScheduler, new TaskSet(tasks, stageId, 0, 0, null), 0) } - def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int) { + def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int): Unit = { val taskSetQueue = rootPool.getSortedTaskSetQueue val nextTaskSetToSchedule = taskSetQueue.find(t => (t.runningTasks + t.tasksSuccessful) < t.numTasks) @@ -201,12 +202,96 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { verifyPool(rootPool, "pool_with_surrounded_whitespace", 3, 2, FAIR) } + /** + * spark.scheduler.pool property should be ignored for the FIFO scheduler, + * because pools are only needed for fair scheduling. + */ + test("FIFO scheduler uses root pool and not spark.scheduler.pool property") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FIFO, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + + val taskSetManager0 = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + val taskSetManager1 = createTaskSetManager(stageId = 1, numTasks = 1, taskScheduler) + + val properties = new Properties() + properties.setProperty("spark.scheduler.pool", TEST_POOL) + + // When FIFO Scheduler is used and task sets are submitted, they should be added to + // the root pool, and no additional pools should be created + // (even though there's a configured default pool). + schedulableBuilder.addTaskSetManager(taskSetManager0, properties) + schedulableBuilder.addTaskSetManager(taskSetManager1, properties) + + assert(rootPool.getSchedulableByName(TEST_POOL) === null) + assert(rootPool.schedulableQueue.size === 2) + assert(rootPool.getSchedulableByName(taskSetManager0.name) === taskSetManager0) + assert(rootPool.getSchedulableByName(taskSetManager1.name) === taskSetManager1) + } + + test("FAIR Scheduler uses default pool when spark.scheduler.pool property is not set") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FAIR, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + // Submit a new task set manager with pool properties set to null. This should result + // in the task set manager getting added to the default pool. 
+ val taskSetManager0 = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + + val defaultPool = rootPool.getSchedulableByName(schedulableBuilder.DEFAULT_POOL_NAME) + assert(defaultPool !== null) + assert(defaultPool.schedulableQueue.size === 1) + assert(defaultPool.getSchedulableByName(taskSetManager0.name) === taskSetManager0) + + // When a task set manager is submitted with spark.scheduler.pool unset, it should be added to + // the default pool (as above). + val taskSetManager1 = createTaskSetManager(stageId = 1, numTasks = 1, taskScheduler) + schedulableBuilder.addTaskSetManager(taskSetManager1, new Properties()) + + assert(defaultPool.schedulableQueue.size === 2) + assert(defaultPool.getSchedulableByName(taskSetManager1.name) === taskSetManager1) + } + + test("FAIR Scheduler creates a new pool when spark.scheduler.pool property points to " + + "a non-existent pool") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FAIR, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName(TEST_POOL) === null) + + val taskSetManager = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + + val properties = new Properties() + properties.setProperty(schedulableBuilder.FAIR_SCHEDULER_PROPERTIES, TEST_POOL) + + // The fair scheduler should create a new pool with default values when spark.scheduler.pool + // points to a pool that doesn't exist yet (this can happen when the file that pools are read + // from isn't set, or when that file doesn't contain the pool name specified + // by spark.scheduler.pool). + schedulableBuilder.addTaskSetManager(taskSetManager, properties) + + verifyPool(rootPool, TEST_POOL, schedulableBuilder.DEFAULT_MINIMUM_SHARE, + schedulableBuilder.DEFAULT_WEIGHT, schedulableBuilder.DEFAULT_SCHEDULING_MODE) + val testPool = rootPool.getSchedulableByName(TEST_POOL) + assert(testPool.getSchedulableByName(taskSetManager.name) === taskSetManager) + } + private def verifyPool(rootPool: Pool, poolName: String, expectedInitMinShare: Int, expectedInitWeight: Int, expectedSchedulingMode: SchedulingMode): Unit = { - assert(rootPool.getSchedulableByName(poolName) != null) - assert(rootPool.getSchedulableByName(poolName).minShare === expectedInitMinShare) - assert(rootPool.getSchedulableByName(poolName).weight === expectedInitWeight) - assert(rootPool.getSchedulableByName(poolName).schedulingMode === expectedSchedulingMode) + val selectedPool = rootPool.getSchedulableByName(poolName) + assert(selectedPool !== null) + assert(selectedPool.minShare === expectedInitMinShare) + assert(selectedPool.weight === expectedInitWeight) + assert(selectedPool.schedulingMode === expectedSchedulingMode) } - } From 7d734a658349e8691d8b4294454c9cd98d555014 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 16 Mar 2017 08:18:36 +0800 Subject: [PATCH 045/512] [SPARK-19931][SQL] InMemoryTableScanExec should rewrite output partitioning and ordering when aliasing output attributes ## What changes were proposed in this pull request? Now `InMemoryTableScanExec` simply takes the `outputPartitioning` and `outputOrdering` from the associated `InMemoryRelation`'s `child.outputPartitioning` and `outputOrdering`. However, `InMemoryTableScanExec` can alias the output attributes. 
In this case, its `outputPartitioning` and `outputOrdering` are not correct and its parent operators can't correctly determine its data distribution. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17175 from viirya/ensure-no-unnecessary-shuffle. --- .../columnar/InMemoryTableScanExec.scala | 21 ++++++++++++--- .../columnar/InMemoryColumnarQuerySuite.scala | 26 +++++++++++++++++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 9028caa446e8..214e8d309de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType @@ -41,11 +41,26 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes + private def updateAttribute(expr: Expression): Expression = { + val attrMap = AttributeMap(relation.child.output.zip(output)) + expr.transform { + case attr: Attribute => attrMap.getOrElse(attr, attr) + } + } + // The cached version does not change the outputPartitioning of the original SparkPlan. - override def outputPartitioning: Partitioning = relation.child.outputPartitioning + // But the cached version could alias output, so we need to replace output. + override def outputPartitioning: Partitioning = { + relation.child.outputPartitioning match { + case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] + case _ => relation.child.outputPartitioning + } + } // The cached version does not change the outputOrdering of the original SparkPlan. - override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering + // But the cached version could alias output, so we need to replace output. 
+ override def outputOrdering: Seq[SortOrder] = + relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 0250a53fe232..1e6a6a8ba336 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,6 +21,9 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -388,4 +391,27 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } } + test("InMemoryTableScanExec should return correct output ordering and partitioning") { + val df1 = Seq((0, 0), (1, 1)).toDF + .repartition(col("_1")).sortWithinPartitions(col("_1")).persist + val df2 = Seq((0, 0), (1, 1)).toDF + .repartition(col("_1")).sortWithinPartitions(col("_1")).persist + + // Because two cached dataframes have the same logical plan, this is a self-join actually. + // So we force one of in-memory relation to alias its output. Then we can test if original and + // aliased in-memory relations have correct ordering and partitioning. + val joined = df1.joinWith(df2, df1("_1") === df2("_1")) + + val inMemoryScans = joined.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + inMemoryScans.foreach { inMemoryScan => + val sortedAttrs = AttributeSet(inMemoryScan.outputOrdering.flatMap(_.references)) + assert(sortedAttrs.subsetOf(inMemoryScan.outputSet)) + + val partitionedAttrs = + inMemoryScan.outputPartitioning.asInstanceOf[HashPartitioning].references + assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet)) + } + } } From 339b237dc18d4367b0735236b4b8be2901fcad79 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 16 Mar 2017 08:20:47 +0800 Subject: [PATCH 046/512] [SPARK-19948] Document that saveAsTable uses catalog as source of truth for table existence. It is quirky behaviour that saveAsTable to e.g. a JDBC source with SaveMode other than Overwrite will nevertheless overwrite the table in the external source, if that table was not a catalog table. Author: Juliusz Sompolski Closes #17289 from juliuszsompolski/saveAsTableDoc. --- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index deaa8006945c..3e975ef6a3c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -337,6 +337,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * +---+---+ * }}} * + * In this method, save mode is used to determine the behavior if the data source table exists in + * Spark catalog. 
We will always overwrite the underlying data of data source (e.g. a table in + * JDBC data source) if the table doesn't exist in Spark catalog, and will always append to the + * underlying data of data source if the table already exists. + * * When the DataFrame is created from a non-partitioned `HadoopFsRelation` with a single input * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC * and Parquet), the table is persisted in a Hive compatible format, which means other systems From fc9314671c8a082ae339fd6df177a2b684c65d40 Mon Sep 17 00:00:00 2001 From: windpiger Date: Thu, 16 Mar 2017 08:44:57 +0800 Subject: [PATCH 047/512] [SPARK-19961][SQL][MINOR] Unify the error message when dropping a non-empty database for HiveExternalCatalog and InMemoryCatalog ## What changes were proposed in this pull request? Unify the exception error message for DROP DATABASE when the database still has some tables, for HiveExternalCatalog and InMemoryCatalog. ## How was this patch tested? N/A Author: windpiger Closes #17305 from windpiger/unifyErromsg. --- .../spark/sql/catalyst/catalog/InMemoryCatalog.scala | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 5cc6b0abc6fd..cdf618aef97c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -127,7 +127,7 @@ class InMemoryCatalog( if (!cascade) { // If cascade is false, make sure the database is empty. if (catalog(db).tables.nonEmpty) { - throw new AnalysisException(s"Database '$db' is not empty. One or more tables exist.") + throw new AnalysisException(s"Database $db is not empty. One or more tables exist.") } if (catalog(db).functions.nonEmpty) { throw new AnalysisException(s"Database '$db' is not empty. One or more functions exist.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 6eed10ec5146..dd76fdde06cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -617,12 +617,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName RESTRICT") }.getMessage - // TODO: Unify the exception. - if (isUsingHiveMetastore) { - assert(message.contains(s"Database $dbName is not empty. One or more tables exist")) - } else { - assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist")) - } + assert(message.contains(s"Database $dbName is not empty. One or more tables exist")) + catalog.dropTable(tableIdent1, ignoreIfNotExists = false, purge = false) From 21f333c635465069b7657d788052d510ffb0779a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 16 Mar 2017 08:50:01 +0800 Subject: [PATCH 048/512] [SPARK-19751][SQL] Throw an exception if bean class has one's own class in fields ## What changes were proposed in this pull request? 
The current master throws `StackOverflowError` in `createDataFrame`/`createDataset` if bean has one's own class in fields; ``` public class SelfClassInFieldBean implements Serializable { private SelfClassInFieldBean child; ... } ``` This pr added code to throw `UnsupportedOperationException` in that case as soon as possible. ## How was this patch tested? Added tests in `JavaDataFrameSuite` and `JavaDatasetSuite`. Author: Takeshi Yamamuro Closes #17188 from maropu/SPARK-19751. --- .../sql/catalyst/JavaTypeInference.scala | 19 ++-- .../apache/spark/sql/JavaDataFrameSuite.java | 32 +++++++ .../apache/spark/sql/JavaDatasetSuite.java | 87 +++++++++++++++++++ 3 files changed, 132 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index e9d9508e5adf..4ff87edde139 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -69,7 +69,8 @@ object JavaTypeInference { * @param typeToken Java type * @return (SQL data type, nullable) */ - private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty) + : (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) @@ -104,26 +105,32 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) case _ if typeToken.isArray => - val (dataType, nullable) = inferDataType(typeToken.getComponentType) + val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet) (ArrayType(dataType, nullable), true) case _ if iterableType.isAssignableFrom(typeToken) => - val (dataType, nullable) = inferDataType(elementType(typeToken)) + val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet) (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val (keyDataType, _) = inferDataType(keyType) - val (valueDataType, nullable) = inferDataType(valueType) + val (keyDataType, _) = inferDataType(keyType, seenTypeSet) + val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet) (MapType(keyDataType, valueDataType, nullable), true) case other => + if (seenTypeSet.contains(other)) { + throw new UnsupportedOperationException( + "Cannot have circular references in bean class, but got the circular reference " + + s"of class $other") + } + // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. 
val properties = getJavaBeanReadableProperties(other) val fields = properties.map { property => val returnType = typeToken.method(property.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(returnType) + val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other) new StructField(property.getName, dataType, nullable) } (new StructType(fields), true) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index be8d95d0d912..b007093dad84 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -423,4 +423,36 @@ public void testJsonRDDToDataFrame() { Assert.assertEquals(1L, df.count()); Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0)); } + + public class CircularReference1Bean implements Serializable { + private CircularReference2Bean child; + + public CircularReference2Bean getChild() { + return child; + } + + public void setChild(CircularReference2Bean child) { + this.child = child; + } + } + + public class CircularReference2Bean implements Serializable { + private CircularReference1Bean child; + + public CircularReference1Bean getChild() { + return child; + } + + public void setChild(CircularReference1Bean child) { + this.child = child; + } + } + + // Checks a simple case for DataFrame here and put exhaustive tests for the issue + // of circular references in `JavaDatasetSuite`. + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean() { + CircularReference1Bean bean = new CircularReference1Bean(); + spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index d06e35bb44d0..439cac3dfbcb 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1291,4 +1291,91 @@ public void testEmptyBean() { Assert.assertEquals(df.schema().length(), 0); Assert.assertEquals(df.collectAsList().size(), 1); } + + public class CircularReference1Bean implements Serializable { + private CircularReference2Bean child; + + public CircularReference2Bean getChild() { + return child; + } + + public void setChild(CircularReference2Bean child) { + this.child = child; + } + } + + public class CircularReference2Bean implements Serializable { + private CircularReference1Bean child; + + public CircularReference1Bean getChild() { + return child; + } + + public void setChild(CircularReference1Bean child) { + this.child = child; + } + } + + public class CircularReference3Bean implements Serializable { + private CircularReference3Bean[] child; + + public CircularReference3Bean[] getChild() { + return child; + } + + public void setChild(CircularReference3Bean[] child) { + this.child = child; + } + } + + public class CircularReference4Bean implements Serializable { + private Map child; + + public Map getChild() { + return child; + } + + public void setChild(Map child) { + this.child = child; + } + } + + public class CircularReference5Bean implements Serializable { + private String id; + private List child; + + public String getId() { + return id; + } + + public List getChild() { + return child; + } + + public void setId(String id) { + this.id = id; 
+ } + + public void setChild(List child) { + this.child = child; + } + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean1() { + CircularReference1Bean bean = new CircularReference1Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference1Bean.class)); + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean2() { + CircularReference3Bean bean = new CircularReference3Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference3Bean.class)); + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean3() { + CircularReference4Bean bean = new CircularReference4Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class)); + } } From 1472cac4bb31c1886f82830778d34c4dd9030d7a Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 16 Mar 2017 12:06:20 +0800 Subject: [PATCH 049/512] [SPARK-19830][SQL] Add parseTableSchema API to ParserInterface ### What changes were proposed in this pull request? Specifying the table schema in DDL formats is needed for different scenarios. For example, - [specifying the schema in SQL function `from_json` using DDL formats](https://issues.apache.org/jira/browse/SPARK-19637), which is suggested by marmbrus , - [specifying the customized JDBC data types](https://github.com/apache/spark/pull/16209). These two PRs need users to use the JSON format to specify the table schema. This is not user friendly. This PR is to provide a `parseTableSchema` API in `ParserInterface`. ### How was this patch tested? Added a test suite `TableSchemaParserSuite` Author: Xiao Li Closes #17171 from gatorsmile/parseDDLStmt. --- .../sql/catalyst/parser/ParseDriver.scala | 10 ++- .../sql/catalyst/parser/ParserInterface.scala | 7 ++ .../parser/TableSchemaParserSuite.scala | 88 +++++++++++++++++++ 3 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index d687a85c18b6..f704b0998cad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} /** * Base SQL parsing infrastructure. @@ -49,6 +49,14 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) } + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + override def parseTableSchema(sqlText: String): StructType = parse(sqlText) { parser => + StructType(astBuilder.visitColTypeList(parser.colTypeList())) + } + /** Creates LogicalPlan for a given SQL string. 
*/ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => astBuilder.visitSingleStatement(parser.singleStatement()) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index 7f35d650b957..6edbe253970e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.StructType /** * Interface for a parser. @@ -33,4 +34,10 @@ trait ParserInterface { /** Creates TableIdentifier for a given SQL string. */ def parseTableIdentifier(sqlText: String): TableIdentifier + + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + def parseTableSchema(sqlText: String): StructType } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala new file mode 100644 index 000000000000..da1041d61708 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -0,0 +1,88 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class TableSchemaParserSuite extends SparkFunSuite { + + def parse(sql: String): StructType = CatalystSqlParser.parseTableSchema(sql) + + def checkTableSchema(tableSchemaString: String, expectedDataType: DataType): Unit = { + test(s"parse $tableSchemaString") { + assert(parse(tableSchemaString) === expectedDataType) + } + } + + def assertError(sql: String): Unit = + intercept[ParseException](CatalystSqlParser.parseTableSchema(sql)) + + checkTableSchema("a int", new StructType().add("a", "int")) + checkTableSchema("A int", new StructType().add("A", "int")) + checkTableSchema("a INT", new StructType().add("a", "int")) + checkTableSchema("`!@#$%.^&*()` string", new StructType().add("!@#$%.^&*()", "string")) + checkTableSchema("a int, b long", new StructType().add("a", "int").add("b", "long")) + checkTableSchema("a STRUCT", + StructType( + StructField("a", StructType( + StructField("intType", IntegerType) :: + StructField("ts", TimestampType) :: Nil)) :: Nil)) + checkTableSchema( + "a int comment 'test'", + new StructType().add("a", "int", nullable = true, "test")) + + test("complex hive type") { + val tableSchemaString = + """ + |complexStructCol struct< + |struct:struct, + |MAP:Map, + |arrAy:Array, + |anotherArray:Array> + """.stripMargin.replace("\n", "") + + val builder = new MetadataBuilder + builder.putString(HIVE_TYPE_STRING, + "struct," + + "MAP:map,arrAy:array,anotherArray:array>") + + val expectedDataType = + StructType( + StructField("complexStructCol", StructType( + StructField("struct", + StructType( + StructField("deciMal", DecimalType.USER_DEFAULT) :: + StructField("anotherDecimal", DecimalType(5, 2)) :: Nil)) :: + StructField("MAP", MapType(TimestampType, StringType)) :: + StructField("arrAy", ArrayType(DoubleType)) :: + StructField("anotherArray", ArrayType(StringType)) :: Nil), + nullable = true, + builder.build()) :: Nil) + + assert(parse(tableSchemaString) === expectedDataType) + } + + // Negative cases + assertError("") + assertError("a") + assertError("a INT b long") + assertError("a INT,, b long") + assertError("a INT, b long,,") + assertError("a INT, b long, c int,") +} From d647aae278ef31a07fc64715eb07e48294d94bb8 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 16 Mar 2017 12:49:59 +0200 Subject: [PATCH 050/512] [SPARK-13568][ML] Create feature transformer to impute missing values ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-13568 It is quite common to encounter missing values in data sets. It would be useful to implement a Transformer that can impute missing data points, similar to e.g. Imputer in scikit-learn. Initially, options for imputation could include mean, median and most frequent, but we could add various other approaches, where possible existing DataFrame code can be used (e.g. for approximate quantiles etc). Currently this PR supports imputation for Double and Vector (null and NaN in Vector). ## How was this patch tested? new unit tests and manual test Author: Yuhao Yang Author: Yuhao Yang Author: Yuhao Closes #11601 from hhbyyh/imputer. 
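A minimal usage sketch of the new estimator (illustrative only, not part of this patch; the DataFrame and column names are made up, and the API calls mirror the `Imputer` added by this commit):

```scala
// Illustrative sketch only (not part of this patch). Column names are hypothetical;
// assumes a SparkSession named `spark` is in scope (e.g. in spark-shell).
import org.apache.spark.ml.feature.Imputer

val df = spark.createDataFrame(Seq(
  (1.0, Double.NaN),
  (2.0, 3.0),
  (Double.NaN, 5.0)
)).toDF("a", "b")

val imputer = new Imputer()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("a_imputed", "b_imputed"))
  .setStrategy("mean") // or "median"; Double.NaN is the default missingValue

// fit() computes one surrogate (mean or median) per input column, skipping nulls,
// NaNs and missingValue occurrences; transform() adds the filled-in output columns.
val model = imputer.fit(df)
model.transform(df).show()
```

With the default `missingValue` (`Double.NaN`), both NaN and null entries in `a` and `b` are replaced by the per-column mean, or by the approximate median when `setStrategy("median")` is used.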
--- .../org/apache/spark/ml/feature/Imputer.scala | 259 ++++++++++++++++++ .../spark/ml/feature/ImputerSuite.scala | 185 +++++++++++++ 2 files changed, 444 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala new file mode 100644 index 000000000000..b1a802ee13fc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.HasInputCols +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * Params for [[Imputer]] and [[ImputerModel]]. + */ +private[feature] trait ImputerParams extends Params with HasInputCols { + + /** + * The imputation strategy. + * If "mean", then replace missing values using the mean value of the feature. + * If "median", then replace missing values using the approximate median value of the feature. + * Default: mean + * + * @group param + */ + final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " + + s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " + + s"If ${Imputer.median}, then replace missing values using the median value of the feature.", + ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median))) + + /** @group getParam */ + def getStrategy: String = $(strategy) + + /** + * The placeholder for the missing values. All occurrences of missingValue will be imputed. + * Note that null values are always treated as missing. + * Default: Double.NaN + * + * @group param + */ + final val missingValue: DoubleParam = new DoubleParam(this, "missingValue", + "The placeholder for the missing values. All occurrences of missingValue will be imputed") + + /** @group getParam */ + def getMissingValue: Double = $(missingValue) + + /** + * Param for output column names. + * @group param + */ + final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", + "output column names") + + /** @group getParam */ + final def getOutputCols: Array[String] = $(outputCols) + + /** Validates and transforms the input schema. 
*/ + protected def validateAndTransformSchema(schema: StructType): StructType = { + require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" + + s" duplicates: (${$(inputCols).mkString(", ")})") + require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" + + s" duplicates: (${$(outputCols).mkString(", ")})") + require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" + + s" and outputCols(${$(outputCols).length}) should have the same length") + val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => + val inputField = schema(inputCol) + SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType)) + StructField(outputCol, inputField.dataType, inputField.nullable) + } + StructType(schema ++ outputFields) + } +} + +/** + * :: Experimental :: + * Imputation estimator for completing missing values, either using the mean or the median + * of the column in which the missing values are located. The input column should be of + * DoubleType or FloatType. Currently Imputer does not support categorical features yet + * (SPARK-15041) and possibly creates incorrect values for a categorical feature. + * + * Note that the mean/median value is computed after filtering out missing values. + * All Null values in the input column are treated as missing, and so are also imputed. For + * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. + */ +@Experimental +class Imputer @Since("2.2.0")(override val uid: String) + extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable { + + @Since("2.2.0") + def this() = this(Identifiable.randomUID("imputer")) + + /** @group setParam */ + @Since("2.2.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.2.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + /** + * Imputation strategy. Available options are ["mean", "median"]. + * @group setParam + */ + @Since("2.2.0") + def setStrategy(value: String): this.type = set(strategy, value) + + /** @group setParam */ + @Since("2.2.0") + def setMissingValue(value: Double): this.type = set(missingValue, value) + + setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN) + + override def fit(dataset: Dataset[_]): ImputerModel = { + transformSchema(dataset.schema, logging = true) + val spark = dataset.sparkSession + import spark.implicits._ + val surrogates = $(inputCols).map { inputCol => + val ic = col(inputCol) + val filtered = dataset.select(ic.cast(DoubleType)) + .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN) + if(filtered.take(1).length == 0) { + throw new SparkException(s"surrogate cannot be computed. 
" + + s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})") + } + val surrogate = $(strategy) match { + case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first() + case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head + } + surrogate + } + + val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates))) + val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false))) + val surrogateDF = spark.createDataFrame(rows, schema) + copyValues(new ImputerModel(uid, surrogateDF).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): Imputer = defaultCopy(extra) +} + +@Since("2.2.0") +object Imputer extends DefaultParamsReadable[Imputer] { + + /** strategy names that Imputer currently supports. */ + private[ml] val mean = "mean" + private[ml] val median = "median" + + @Since("2.2.0") + override def load(path: String): Imputer = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by [[Imputer]]. + * + * @param surrogateDF a DataFrame contains inputCols and their corresponding surrogates, which are + * used to replace the missing values in the input DataFrame. + */ +@Experimental +class ImputerModel private[ml]( + override val uid: String, + val surrogateDF: DataFrame) + extends Model[ImputerModel] with ImputerParams with MLWritable { + + import ImputerModel._ + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + var outputDF = dataset + val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq + + $(inputCols).zip($(outputCols)).zip(surrogates).foreach { + case ((inputCol, outputCol), surrogate) => + val inputType = dataset.schema(inputCol).dataType + val ic = col(inputCol) + outputDF = outputDF.withColumn(outputCol, + when(ic.isNull, surrogate) + .when(ic === $(missingValue), surrogate) + .otherwise(ic) + .cast(inputType)) + } + outputDF.toDF() + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): ImputerModel = { + val copied = new ImputerModel(uid, surrogateDF) + copyValues(copied, extra).setParent(parent) + } + + @Since("2.2.0") + override def write: MLWriter = new ImputerModelWriter(this) +} + + +@Since("2.2.0") +object ImputerModel extends MLReadable[ImputerModel] { + + private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val dataPath = new Path(path, "data").toString + instance.surrogateDF.repartition(1).write.parquet(dataPath) + } + } + + private class ImputerReader extends MLReader[ImputerModel] { + + private val className = classOf[ImputerModel].getName + + override def load(path: String): ImputerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val surrogateDF = sqlContext.read.parquet(dataPath) + val model = new ImputerModel(metadata.uid, surrogateDF) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + 
@Since("2.2.0") + override def read: MLReader[ImputerModel] = new ImputerReader + + @Since("2.2.0") + override def load(path: String): ImputerModel = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala new file mode 100644 index 000000000000..ee2ba73fa96d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("Imputer for Double with default missing Value NaN") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0), + (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0), + (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0), + (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0) + )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1", + "expected_mean_value2", "expected_median_value2") + val imputer = new Imputer() + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1", "out2")) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN), + (3, -1.0, 2.0, 3.0) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setMissingValue(-1.0) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer for Float with missing Value -1.0") { + val df = spark.createDataFrame( Seq( + (0, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setMissingValue(-1) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer should impute null as well as 'missingValue'") { + val rawDf = spark.createDataFrame( Seq( + (0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0), + (4, -1.0, 8.0, 10.0) + )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val df = 
rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer throws exception when surrogate cannot be computed") { + val df = spark.createDataFrame( Seq( + (0, Double.NaN, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + Seq("mean", "median").foreach { strategy => + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setStrategy(strategy) + withClue("Imputer should fail all the values are invalid") { + val e: SparkException = intercept[SparkException] { + val model = imputer.fit(df) + } + assert(e.getMessage.contains("surrogate cannot be computed")) + } + } + } + + test("Imputer input & output column validation") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value1", "value2", "value3") + Seq("mean", "median").foreach { strategy => + withClue("Imputer should fail if inputCols and outputCols are different length") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("should have the same length")) + } + + withClue("Imputer should fail if inputCols contains duplicates") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value1")) + .setOutputCols(Array("out1", "out2")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("inputCols contains duplicates")) + } + + withClue("Imputer should fail if outputCols contains duplicates") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1", "out1")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("outputCols contains duplicates")) + } + } + } + + test("Imputer read/write") { + val t = new Imputer() + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + .setMissingValue(-1.0) + testDefaultReadWrite(t) + } + + test("ImputerModel read/write") { + val spark = this.spark + import spark.implicits._ + val surrogateDF = Seq(1.234).toDF("myInputCol") + + val instance = new ImputerModel( + "myImputer", surrogateDF) + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns) + assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) + } + +} + +object ImputerSuite { + + /** + * Imputation strategy. Available options are ["mean", "median"]. 
+ * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" + */ + def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { + val inputCols = imputer.getInputCols + + Seq("mean", "median").foreach { strategy => + imputer.setStrategy(strategy) + val model = imputer.fit(df) + val resultDF = model.transform(df) + imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => + resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { + case Row(exp: Float, out: Float) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), + s"Imputed values differ. Expected: $exp, actual: $out") + } + } + } + } +} From ee91a0decc389572099ea7c038149cc50375a2ef Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Thu, 16 Mar 2017 15:25:45 +0100 Subject: [PATCH 051/512] [SPARK-19946][TESTING] DebugFilesystem.assertNoOpenStreams should report the open streams to help debugging ## What changes were proposed in this pull request? DebugFilesystem.assertNoOpenStreams throws an exception with a cause exception that actually shows the code line which leaked the stream. ## How was this patch tested? New test in SparkContextSuite to check there is a cause exception. Author: Bogdan Raducanu Closes #17292 from bogdanrdc/SPARK-19946. --- .../org/apache/spark/DebugFilesystem.scala | 3 ++- .../org/apache/spark/SparkContextSuite.scala | 20 ++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index fb8d701ebda8..72aea841117c 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -44,7 +44,8 @@ object DebugFilesystem extends Logging { logWarning("Leaked filesystem connection created at:") exc.printStackTrace() } - throw new RuntimeException(s"There are $numOpen possibly leaked file streams.") + throw new IllegalStateException(s"There are $numOpen possibly leaked file streams.", + openStreams.values().asScala.head) } } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index f97a112ec127..d08a162feda0 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import java.io.File -import java.net.MalformedURLException +import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -26,6 +26,8 @@ import scala.concurrent.duration._ import scala.concurrent.Await import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} @@ -538,6 +540,22 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " + + "open streams to help debugging") { + val fs = new DebugFilesystem() + fs.initialize(new URI("file:///"), new Configuration()) + val 
file = File.createTempFile("SPARK19446", "temp") + Files.write(Array.ofDim[Byte](1000), file) + val path = new Path("file:///" + file.getCanonicalPath) + val stream = fs.open(path) + val exc = intercept[RuntimeException] { + DebugFilesystem.assertNoOpenStreams() + } + assert(exc != null) + assert(exc.getCause() != null) + stream.close() + } + } object SparkContextSuite { From 8e8f898335f5019c0d4f3944c4aefa12a185db70 Mon Sep 17 00:00:00 2001 From: windpiger Date: Thu, 16 Mar 2017 11:34:13 -0700 Subject: [PATCH 052/512] [SPARK-19945][SQL] add test suite for SessionCatalog with HiveExternalCatalog ## What changes were proposed in this pull request? Currently `SessionCatalogSuite` is only for `InMemoryCatalog`, there is no suite for `HiveExternalCatalog`. And there are some ddl function is not proper to test in `ExternalCatalogSuite`, because some logic are not full implement in `ExternalCatalog`, these ddl functions are full implement in `SessionCatalog`(e.g. merge the same logic from `ExternalCatalog` up to `SessionCatalog` ). It is better to test it in `SessionCatalogSuite` for this situation. So we should add a test suite for `SessionCatalog` with `HiveExternalCatalog` The main change is that in `SessionCatalogSuite` add two functions: `withBasicCatalog` and `withEmptyCatalog` And replace the code like `val catalog = new SessionCatalog(newBasicCatalog)` with above two functions ## How was this patch tested? add `HiveExternalSessionCatalogSuite` Author: windpiger Closes #17287 from windpiger/sessioncatalogsuit. --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 1907 +++++++++-------- .../HiveExternalSessionCatalogSuite.scala | 40 + 3 files changed, 1049 insertions(+), 900 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index bfcdb70fe47c..25aa8d3ba921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -48,7 +48,7 @@ object SessionCatalog { * This class must be thread-safe. 
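The commit message above replaces per-test construction of a SessionCatalog with loan-pattern helpers so the same suite body can run against either catalog implementation. A condensed sketch of that pattern, simplified from the SessionCatalogSuite diff that follows (the helper body and the sample test are abbreviated):

```scala
// Simplified from the diff below: create a catalog, lend it to the test
// body, and always reset it afterwards so tests stay isolated.
private def withBasicCatalog(f: SessionCatalog => Unit): Unit = {
  val catalog = new SessionCatalog(newBasicCatalog())
  catalog.createDatabase(newDb("default"), ignoreIfExists = true)
  try f(catalog) finally catalog.reset()
}

// Tests become backend-agnostic: the same body runs against InMemoryCatalog
// or HiveExternalCatalog, depending on which concrete suite supplies
// newBasicCatalog() through CatalogTestUtils.
test("list databases without pattern") {
  withBasicCatalog { catalog =>
    assert(catalog.listDatabases().toSet == Set("default", "db1", "db2", "db3"))
  }
}
```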
*/ class SessionCatalog( - externalCatalog: ExternalCatalog, + val externalCatalog: ExternalCatalog, globalTempViewManager: GlobalTempViewManager, functionRegistry: FunctionRegistry, conf: CatalystConf, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 7e74dcdef0e2..bb87763e0bbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -27,41 +27,67 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias, View} +class InMemorySessionCatalogSuite extends SessionCatalogSuite { + protected val utils = new CatalogTestUtils { + override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" + override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" + override val defaultProvider: String = "parquet" + override def newEmptyCatalog(): ExternalCatalog = new InMemoryCatalog + } +} + /** - * Tests for [[SessionCatalog]] that assume that [[InMemoryCatalog]] is correctly implemented. + * Tests for [[SessionCatalog]] * * Note: many of the methods here are very similar to the ones in [[ExternalCatalogSuite]]. * This is because [[SessionCatalog]] and [[ExternalCatalog]] share many similar method * signatures but do not extend a common parent. This is largely by design but * unfortunately leads to very similar test code in two places. */ -class SessionCatalogSuite extends PlanTest { - private val utils = new CatalogTestUtils { - override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" - override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" - override val defaultProvider: String = "parquet" - override def newEmptyCatalog(): ExternalCatalog = new InMemoryCatalog - } +abstract class SessionCatalogSuite extends PlanTest { + protected val utils: CatalogTestUtils + + protected val isHiveExternalCatalog = false import utils._ + private def withBasicCatalog(f: SessionCatalog => Unit): Unit = { + val catalog = new SessionCatalog(newBasicCatalog()) + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + try { + f(catalog) + } finally { + catalog.reset() + } + } + + private def withEmptyCatalog(f: SessionCatalog => Unit): Unit = { + val catalog = new SessionCatalog(newEmptyCatalog()) + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + try { + f(catalog) + } finally { + catalog.reset() + } + } // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- test("basic create and list databases") { - val catalog = new SessionCatalog(newEmptyCatalog()) - catalog.createDatabase(newDb("default"), ignoreIfExists = true) - assert(catalog.databaseExists("default")) - assert(!catalog.databaseExists("testing")) - assert(!catalog.databaseExists("testing2")) - catalog.createDatabase(newDb("testing"), ignoreIfExists = false) - assert(catalog.databaseExists("testing")) - assert(catalog.listDatabases().toSet == Set("default", "testing")) - catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) - assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) - 
assert(catalog.databaseExists("testing2")) - assert(!catalog.databaseExists("does_not_exist")) + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + assert(catalog.databaseExists("default")) + assert(!catalog.databaseExists("testing")) + assert(!catalog.databaseExists("testing2")) + catalog.createDatabase(newDb("testing"), ignoreIfExists = false) + assert(catalog.databaseExists("testing")) + assert(catalog.listDatabases().toSet == Set("default", "testing")) + catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) + assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) + assert(catalog.databaseExists("testing2")) + assert(!catalog.databaseExists("does_not_exist")) + } } def testInvalidName(func: (String) => Unit) { @@ -76,121 +102,141 @@ class SessionCatalogSuite extends PlanTest { } test("create databases using invalid names") { - val catalog = new SessionCatalog(newEmptyCatalog()) - testInvalidName(name => catalog.createDatabase(newDb(name), ignoreIfExists = true)) + withEmptyCatalog { catalog => + testInvalidName( + name => catalog.createDatabase(newDb(name), ignoreIfExists = true)) + } } test("get database when a database exists") { - val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabaseMetadata("db1") - assert(db1.name == "db1") - assert(db1.description.contains("db1")) + withBasicCatalog { catalog => + val db1 = catalog.getDatabaseMetadata("db1") + assert(db1.name == "db1") + assert(db1.description.contains("db1")) + } } test("get database should throw exception when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getDatabaseMetadata("db_that_does_not_exist") + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getDatabaseMetadata("db_that_does_not_exist") + } } } test("list databases without pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listDatabases().toSet == Set("default", "db1", "db2", "db3")) + withBasicCatalog { catalog => + assert(catalog.listDatabases().toSet == Set("default", "db1", "db2", "db3")) + } } test("list databases with pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listDatabases("db").toSet == Set.empty) - assert(catalog.listDatabases("db*").toSet == Set("db1", "db2", "db3")) - assert(catalog.listDatabases("*1").toSet == Set("db1")) - assert(catalog.listDatabases("db2").toSet == Set("db2")) + withBasicCatalog { catalog => + assert(catalog.listDatabases("db").toSet == Set.empty) + assert(catalog.listDatabases("db*").toSet == Set("db1", "db2", "db3")) + assert(catalog.listDatabases("*1").toSet == Set("db1")) + assert(catalog.listDatabases("db2").toSet == Set("db2")) + } } test("drop database") { - val catalog = new SessionCatalog(newBasicCatalog()) - catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) - assert(catalog.listDatabases().toSet == Set("default", "db2", "db3")) + withBasicCatalog { catalog => + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) + assert(catalog.listDatabases().toSet == Set("default", "db2", "db3")) + } } test("drop database when the database is not empty") { // Throw exception if there are functions left - val externalCatalog1 = newBasicCatalog() - val sessionCatalog1 = new SessionCatalog(externalCatalog1) - externalCatalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false, purge = false) - 
externalCatalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false, purge = false) - intercept[AnalysisException] { - sessionCatalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + catalog.externalCatalog.dropTable("db2", "tbl1", ignoreIfNotExists = false, purge = false) + catalog.externalCatalog.dropTable("db2", "tbl2", ignoreIfNotExists = false, purge = false) + intercept[AnalysisException] { + catalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } } - - // Throw exception if there are tables left - val externalCatalog2 = newBasicCatalog() - val sessionCatalog2 = new SessionCatalog(externalCatalog2) - externalCatalog2.dropFunction("db2", "func1") - intercept[AnalysisException] { - sessionCatalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + // Throw exception if there are tables left + catalog.externalCatalog.dropFunction("db2", "func1") + intercept[AnalysisException] { + catalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } } - // When cascade is true, it should drop them - val externalCatalog3 = newBasicCatalog() - val sessionCatalog3 = new SessionCatalog(externalCatalog3) - externalCatalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) - assert(sessionCatalog3.listDatabases().toSet == Set("default", "db1", "db3")) + withBasicCatalog { catalog => + // When cascade is true, it should drop them + catalog.externalCatalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) + assert(catalog.listDatabases().toSet == Set("default", "db1", "db3")) + } } test("drop database when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + // TODO: fix this inconsistent between HiveExternalCatalog and InMemoryCatalog + if (isHiveExternalCatalog) { + val e = intercept[AnalysisException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + }.getMessage + assert(e.contains( + "org.apache.hadoop.hive.metastore.api.NoSuchObjectException: db_that_does_not_exist")) + } else { + intercept[NoSuchDatabaseException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + } + } + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) } - catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) } test("drop current database and drop default database") { - val catalog = new SessionCatalog(newBasicCatalog()) - catalog.setCurrentDatabase("db1") - assert(catalog.getCurrentDatabase == "db1") - catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = true) - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "db1"), ignoreIfExists = false) - } - catalog.setCurrentDatabase("default") - assert(catalog.getCurrentDatabase == "default") - intercept[AnalysisException] { - catalog.dropDatabase("default", ignoreIfNotExists = false, cascade = true) + withBasicCatalog { catalog => + catalog.setCurrentDatabase("db1") + assert(catalog.getCurrentDatabase == "db1") + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = true) + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "db1"), ignoreIfExists = false) + } + 
catalog.setCurrentDatabase("default") + assert(catalog.getCurrentDatabase == "default") + intercept[AnalysisException] { + catalog.dropDatabase("default", ignoreIfNotExists = false, cascade = true) + } } } test("alter database") { - val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabaseMetadata("db1") - // Note: alter properties here because Hive does not support altering other fields - catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) - val newDb1 = catalog.getDatabaseMetadata("db1") - assert(db1.properties.isEmpty) - assert(newDb1.properties.size == 2) - assert(newDb1.properties.get("k") == Some("v3")) - assert(newDb1.properties.get("good") == Some("true")) + withBasicCatalog { catalog => + val db1 = catalog.getDatabaseMetadata("db1") + // Note: alter properties here because Hive does not support altering other fields + catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) + val newDb1 = catalog.getDatabaseMetadata("db1") + assert(db1.properties.isEmpty) + assert(newDb1.properties.size == 2) + assert(newDb1.properties.get("k") == Some("v3")) + assert(newDb1.properties.get("good") == Some("true")) + } } test("alter database should throw exception when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterDatabase(newDb("unknown_db")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterDatabase(newDb("unknown_db")) + } } } test("get/set current database") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.getCurrentDatabase == "default") - catalog.setCurrentDatabase("db2") - assert(catalog.getCurrentDatabase == "db2") - intercept[NoSuchDatabaseException] { + withBasicCatalog { catalog => + assert(catalog.getCurrentDatabase == "default") + catalog.setCurrentDatabase("db2") + assert(catalog.getCurrentDatabase == "db2") + intercept[NoSuchDatabaseException] { + catalog.setCurrentDatabase("deebo") + } + catalog.createDatabase(newDb("deebo"), ignoreIfExists = false) catalog.setCurrentDatabase("deebo") + assert(catalog.getCurrentDatabase == "deebo") } - catalog.createDatabase(newDb("deebo"), ignoreIfExists = false) - catalog.setCurrentDatabase("deebo") - assert(catalog.getCurrentDatabase == "deebo") } // -------------------------------------------------------------------------- @@ -198,346 +244,360 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("create table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db1").isEmpty) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.createTable(newTable("tbl3", "db1"), ignoreIfExists = false) - sessionCatalog.createTable(newTable("tbl3", "db2"), ignoreIfExists = false) - assert(externalCatalog.listTables("db1").toSet == Set("tbl3")) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) - // Create table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db1") - sessionCatalog.createTable(newTable("tbl4"), ignoreIfExists = false) - assert(externalCatalog.listTables("db1").toSet == Set("tbl3", "tbl4")) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db1").isEmpty) 
+ assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.createTable(newTable("tbl3", "db1"), ignoreIfExists = false) + catalog.createTable(newTable("tbl3", "db2"), ignoreIfExists = false) + assert(catalog.externalCatalog.listTables("db1").toSet == Set("tbl3")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + // Create table without explicitly specifying database + catalog.setCurrentDatabase("db1") + catalog.createTable(newTable("tbl4"), ignoreIfExists = false) + assert(catalog.externalCatalog.listTables("db1").toSet == Set("tbl3", "tbl4")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + } } test("create tables using invalid names") { - val catalog = new SessionCatalog(newEmptyCatalog()) - testInvalidName(name => catalog.createTable(newTable(name, "db1"), ignoreIfExists = false)) + withEmptyCatalog { catalog => + testInvalidName(name => catalog.createTable(newTable(name, "db1"), ignoreIfExists = false)) + } } test("create table when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - // Creating table in non-existent database should always fail - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = false) - } - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = true) - } - // Table already exists - intercept[TableAlreadyExistsException] { - catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + withBasicCatalog { catalog => + // Creating table in non-existent database should always fail + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = false) + } + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = true) + } + // Table already exists + intercept[TableAlreadyExistsException] { + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + } + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = true) } - catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = true) } test("create temp table") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable1 = Range(1, 10, 1, 10) - val tempTable2 = Range(1, 20, 2, 10) - catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) - catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) - assert(catalog.getTempView("tbl1") == Option(tempTable1)) - assert(catalog.getTempView("tbl2") == Option(tempTable2)) - assert(catalog.getTempView("tbl3").isEmpty) - // Temporary table already exists - intercept[TempTableAlreadyExistsException] { + withBasicCatalog { catalog => + val tempTable1 = Range(1, 10, 1, 10) + val tempTable2 = Range(1, 20, 2, 10) catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) + assert(catalog.getTempView("tbl1") == Option(tempTable1)) + assert(catalog.getTempView("tbl2") == Option(tempTable2)) + assert(catalog.getTempView("tbl3").isEmpty) + // Temporary table already exists + intercept[TempTableAlreadyExistsException] { + catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + } + // Temporary table already exists but we override it + catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) + assert(catalog.getTempView("tbl1") == Option(tempTable2)) } - // 
Temporary table already exists but we override it - catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) - assert(catalog.getTempView("tbl1") == Option(tempTable2)) } test("drop table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, - purge = false) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) - // Drop table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.dropTable(TableIdentifier("tbl2"), ignoreIfNotExists = false, purge = false) - assert(externalCatalog.listTables("db2").isEmpty) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, + purge = false) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + // Drop table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.dropTable(TableIdentifier("tbl2"), ignoreIfNotExists = false, purge = false) + assert(catalog.externalCatalog.listTables("db2").isEmpty) + } } test("drop table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - // Should always throw exception when the database does not exist - intercept[NoSuchDatabaseException] { - catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = false, - purge = false) - } - intercept[NoSuchDatabaseException] { - catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true, - purge = false) - } - intercept[NoSuchTableException] { - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false, + withBasicCatalog { catalog => + // Should always throw exception when the database does not exist + intercept[NoSuchDatabaseException] { + catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = false, + purge = false) + } + intercept[NoSuchDatabaseException] { + catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true, + purge = false) + } + intercept[NoSuchTableException] { + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false, + purge = false) + } + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true, purge = false) } - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true, - purge = false) } test("drop temp table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable = Range(1, 10, 2, 10) - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is not specified, temp table should be dropped first - sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.getTempView("tbl1") == None) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If temp table does not exist, the table in the current database should be dropped - 
sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) - // If database is specified, temp tables are never dropped - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) - sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, - purge = false) - assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + assert(catalog.getTempView("tbl1") == Some(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is not specified, temp table should be dropped first + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.getTempView("tbl1") == None) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If temp table does not exist, the table in the current database should be dropped + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + // If database is specified, temp tables are never dropped + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + catalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, + purge = false) + assert(catalog.getTempView("tbl1") == Some(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + } } test("rename table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier("tblone")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbl2")) - sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbltwo")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbltwo")) - // Rename table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.renameTable(TableIdentifier("tbltwo"), TableIdentifier("table_two")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "table_two")) - // Renaming "db2.tblone" to "db1.tblones" should fail because databases don't match - intercept[AnalysisException] { - sessionCatalog.renameTable( - TableIdentifier("tblone", Some("db2")), TableIdentifier("tblones", Some("db1"))) - } - // The new table already exists - intercept[TableAlreadyExistsException] { - sessionCatalog.renameTable( - TableIdentifier("tblone", Some("db2")), - TableIdentifier("table_two")) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier("tblone")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tblone", "tbl2")) + catalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbltwo")) + assert(catalog.externalCatalog.listTables("db2").toSet == 
Set("tblone", "tbltwo")) + // Rename table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.renameTable(TableIdentifier("tbltwo"), TableIdentifier("table_two")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tblone", "table_two")) + // Renaming "db2.tblone" to "db1.tblones" should fail because databases don't match + intercept[AnalysisException] { + catalog.renameTable( + TableIdentifier("tblone", Some("db2")), TableIdentifier("tblones", Some("db1"))) + } + // The new table already exists + intercept[TableAlreadyExistsException] { + catalog.renameTable( + TableIdentifier("tblone", Some("db2")), + TableIdentifier("table_two")) + } } } test("rename tables to an invalid name") { - val catalog = new SessionCatalog(newBasicCatalog()) - testInvalidName( - name => catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier(name))) + withBasicCatalog { catalog => + testInvalidName( + name => catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier(name))) + } } test("rename table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.renameTable(TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2")) - } - intercept[NoSuchTableException] { - catalog.renameTable(TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.renameTable(TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2")) + } + intercept[NoSuchTableException] { + catalog.renameTable(TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2")) + } } } test("rename temp table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable = Range(1, 10, 2, 10) - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempView("tbl1") == Option(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is not specified, temp table should be renamed first - sessionCatalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) - assert(sessionCatalog.getTempView("tbl1").isEmpty) - assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is specified, temp tables are never renamed - sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4")) - assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) - assert(sessionCatalog.getTempView("tbl4").isEmpty) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + assert(catalog.getTempView("tbl1") == Option(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is not specified, temp table should be renamed first + catalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) + assert(catalog.getTempView("tbl1").isEmpty) + assert(catalog.getTempView("tbl3") == Option(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is specified, temp tables are never renamed + 
catalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4")) + assert(catalog.getTempView("tbl3") == Option(tempTable)) + assert(catalog.getTempView("tbl4").isEmpty) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) + } } test("alter table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tbl1 = externalCatalog.getTable("db2", "tbl1") - sessionCatalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem"))) - val newTbl1 = externalCatalog.getTable("db2", "tbl1") - assert(!tbl1.properties.contains("toh")) - assert(newTbl1.properties.size == tbl1.properties.size + 1) - assert(newTbl1.properties.get("toh") == Some("frem")) - // Alter table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1"))) - val newestTbl1 = externalCatalog.getTable("db2", "tbl1") - assert(newestTbl1 == tbl1) + withBasicCatalog { catalog => + val tbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem"))) + val newTbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + assert(!tbl1.properties.contains("toh")) + assert(newTbl1.properties.size == tbl1.properties.size + 1) + assert(newTbl1.properties.get("toh") == Some("frem")) + // Alter table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1"))) + val newestTbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + // For hive serde table, hive metastore will set transient_lastDdlTime in table's properties, + // and its value will be modified, here we ignore it when comparing the two tables. 
+ assert(newestTbl1.copy(properties = Map.empty) == tbl1.copy(properties = Map.empty)) + } } test("alter table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterTable(newTable("tbl1", "unknown_db")) - } - intercept[NoSuchTableException] { - catalog.alterTable(newTable("unknown_table", "db2")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterTable(newTable("tbl1", "unknown_db")) + } + intercept[NoSuchTableException] { + catalog.alterTable(newTable("unknown_table", "db2")) + } } } test("get table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) - == externalCatalog.getTable("db2", "tbl1")) - // Get table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1")) - == externalCatalog.getTable("db2", "tbl1")) + withBasicCatalog { catalog => + assert(catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) + == catalog.externalCatalog.getTable("db2", "tbl1")) + // Get table without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getTableMetadata(TableIdentifier("tbl1")) + == catalog.externalCatalog.getTable("db2", "tbl1")) + } } test("get table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) - } - intercept[NoSuchTableException] { - catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) + } + intercept[NoSuchTableException] { + catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) + } } } test("get option of table metadata") { - val externalCatalog = newBasicCatalog() - val catalog = new SessionCatalog(externalCatalog) - assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2"))) - == Option(externalCatalog.getTable("db2", "tbl1"))) - assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty) - intercept[NoSuchDatabaseException] { - catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db"))) + withBasicCatalog { catalog => + assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2"))) + == Option(catalog.externalCatalog.getTable("db2", "tbl1"))) + assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty) + intercept[NoSuchDatabaseException] { + catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db"))) + } } } test("lookup table relation") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable1 = Range(1, 10, 1, 10) - val metastoreTable1 = externalCatalog.getTable("db2", "tbl1") - sessionCatalog.createTempView("tbl1", tempTable1, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - // If we explicitly specify the database, we'll look up the relation in that database - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head - .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) - // Otherwise, we'll first look up a 
temporary table with the same name - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) - == SubqueryAlias("tbl1", tempTable1)) - // Then, if that does not exist, look up the relation in the current database - sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")).children.head - .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + withBasicCatalog { catalog => + val tempTable1 = Range(1, 10, 1, 10) + val metastoreTable1 = catalog.externalCatalog.getTable("db2", "tbl1") + catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + // If we explicitly specify the database, we'll look up the relation in that database + assert(catalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + // Otherwise, we'll first look up a temporary table with the same name + assert(catalog.lookupRelation(TableIdentifier("tbl1")) + == SubqueryAlias("tbl1", tempTable1)) + // Then, if that does not exist, look up the relation in the current database + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.lookupRelation(TableIdentifier("tbl1")).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + } } test("look up view relation") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val metadata = externalCatalog.getTable("db3", "view1") - sessionCatalog.setCurrentDatabase("default") - // Look up a view. - assert(metadata.viewText.isDefined) - val view = View(desc = metadata, output = metadata.schema.toAttributes, - child = CatalystSqlParser.parsePlan(metadata.viewText.get)) - comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1", Some("db3"))), - SubqueryAlias("view1", view)) - // Look up a view using current database of the session catalog. - sessionCatalog.setCurrentDatabase("db3") - comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1")), - SubqueryAlias("view1", view)) + withBasicCatalog { catalog => + val metadata = catalog.externalCatalog.getTable("db3", "view1") + catalog.setCurrentDatabase("default") + // Look up a view. + assert(metadata.viewText.isDefined) + val view = View(desc = metadata, output = metadata.schema.toAttributes, + child = CatalystSqlParser.parsePlan(metadata.viewText.get)) + comparePlans(catalog.lookupRelation(TableIdentifier("view1", Some("db3"))), + SubqueryAlias("view1", view)) + // Look up a view using current database of the session catalog. 
+ catalog.setCurrentDatabase("db3") + comparePlans(catalog.lookupRelation(TableIdentifier("view1")), + SubqueryAlias("view1", view)) + } } test("table exists") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2")))) - assert(catalog.tableExists(TableIdentifier("tbl2", Some("db2")))) - assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) - assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) - assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) - // If database is explicitly specified, do not check temporary tables - val tempTable = Range(1, 10, 1, 10) - assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) - // If database is not explicitly specified, check the current database - catalog.setCurrentDatabase("db2") - assert(catalog.tableExists(TableIdentifier("tbl1"))) - assert(catalog.tableExists(TableIdentifier("tbl2"))) - - catalog.createTempView("tbl3", tempTable, overrideIfExists = false) - // tableExists should not check temp view. - assert(!catalog.tableExists(TableIdentifier("tbl3"))) + withBasicCatalog { catalog => + assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2")))) + assert(catalog.tableExists(TableIdentifier("tbl2", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) + assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) + // If database is explicitly specified, do not check temporary tables + val tempTable = Range(1, 10, 1, 10) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + // If database is not explicitly specified, check the current database + catalog.setCurrentDatabase("db2") + assert(catalog.tableExists(TableIdentifier("tbl1"))) + assert(catalog.tableExists(TableIdentifier("tbl2"))) + + catalog.createTempView("tbl3", tempTable, overrideIfExists = false) + // tableExists should not check temp view. 
+ assert(!catalog.tableExists(TableIdentifier("tbl3"))) + } } test("getTempViewOrPermanentTableMetadata on temporary views") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) - }.getMessage + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) + }.getMessage - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) - }.getMessage + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage - catalog.createTempView("view1", tempTable, overrideIfExists = false) - assert(catalog.getTempViewOrPermanentTableMetadata( - TableIdentifier("view1")).identifier.table == "view1") - assert(catalog.getTempViewOrPermanentTableMetadata( - TableIdentifier("view1")).schema(0).name == "id") + catalog.createTempView("view1", tempTable, overrideIfExists = false) + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).identifier.table == "view1") + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).schema(0).name == "id") - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) - }.getMessage + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage + } } test("list tables without pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - catalog.createTempView("tbl1", tempTable, overrideIfExists = false) - catalog.createTempView("tbl4", tempTable, overrideIfExists = false) - assert(catalog.listTables("db1").toSet == - Set(TableIdentifier("tbl1"), TableIdentifier("tbl4"))) - assert(catalog.listTables("db2").toSet == - Set(TableIdentifier("tbl1"), - TableIdentifier("tbl4"), - TableIdentifier("tbl1", Some("db2")), - TableIdentifier("tbl2", Some("db2")))) - intercept[NoSuchDatabaseException] { - catalog.listTables("unknown_db") + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTempView("tbl4", tempTable, overrideIfExists = false) + assert(catalog.listTables("db1").toSet == + Set(TableIdentifier("tbl1"), TableIdentifier("tbl4"))) + assert(catalog.listTables("db2").toSet == + Set(TableIdentifier("tbl1"), + TableIdentifier("tbl4"), + TableIdentifier("tbl1", Some("db2")), + TableIdentifier("tbl2", Some("db2")))) + intercept[NoSuchDatabaseException] { + catalog.listTables("unknown_db") + } } } test("list tables with pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - catalog.createTempView("tbl1", tempTable, overrideIfExists = false) - catalog.createTempView("tbl4", tempTable, overrideIfExists = false) - assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) - assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) - assert(catalog.listTables("db2", "tbl*").toSet == - Set(TableIdentifier("tbl1"), - TableIdentifier("tbl4"), - TableIdentifier("tbl1", Some("db2")), - TableIdentifier("tbl2", Some("db2")))) - assert(catalog.listTables("db2", "*1").toSet == - 
Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) - intercept[NoSuchDatabaseException] { - catalog.listTables("unknown_db", "*") + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTempView("tbl4", tempTable, overrideIfExists = false) + assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) + assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) + assert(catalog.listTables("db2", "tbl*").toSet == + Set(TableIdentifier("tbl1"), + TableIdentifier("tbl4"), + TableIdentifier("tbl1", Some("db2")), + TableIdentifier("tbl2", Some("db2")))) + assert(catalog.listTables("db2", "*1").toSet == + Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) + intercept[NoSuchDatabaseException] { + catalog.listTables("unknown_db", "*") + } } } @@ -546,451 +606,477 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("basic create and list partitions") { - val externalCatalog = newEmptyCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - sessionCatalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) - sessionCatalog.createPartitions( - TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) - // Create partitions without explicitly specifying database - sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createPartitions( - TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) - assert(catalogPartitionsEqual( - externalCatalog.listPartitions("mydb", "tbl"), part1, part2, partWithMixedOrder)) + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) + catalog.createPartitions( + TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) + // Create partitions without explicitly specifying database + catalog.setCurrentDatabase("mydb") + catalog.createPartitions( + TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("mydb", "tbl"), part1, part2, partWithMixedOrder)) + } } test("create partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.createPartitions( - TableIdentifier("tbl1", Some("unknown_db")), Seq(), ignoreIfExists = false) - } - intercept[NoSuchTableException] { - catalog.createPartitions( - TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.createPartitions( + TableIdentifier("tbl1", Some("unknown_db")), Seq(), ignoreIfExists = false) + } + intercept[NoSuchTableException] { + catalog.createPartitions( + TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfExists = false) + } } } test("create partitions that already exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + withBasicCatalog { catalog => + 
intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = false) + } catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = false) + TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = true) } - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = true) } test("create partitions with invalid part spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1, partWithLessColumns), ignoreIfExists = false) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1, partWithMoreColumns), ignoreIfExists = true) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithUnknownColumns, part1), ignoreIfExists = true) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithEmptyValue, part1), ignoreIfExists = true) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1, partWithLessColumns), ignoreIfExists = false) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1, partWithMoreColumns), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithUnknownColumns, part1), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithEmptyValue, part1), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("drop partitions") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) - sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part2)) - // Drop partitions without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.dropPartitions( - TableIdentifier("tbl2"), - Seq(part2.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) - // Drop multiple partitions at once - sessionCatalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) - sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1.spec, part2.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) - } - - test("drop partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) catalog.dropPartitions( - TableIdentifier("tbl1", Some("unknown_db")), - Seq(), + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec), ignoreIfNotExists = false, purge = false, retainData = false) - } - intercept[NoSuchTableException] { + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part2)) + // Drop partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") catalog.dropPartitions( - TableIdentifier("does_not_exist", Some("db2")), - Seq(), + TableIdentifier("tbl2"), + Seq(part2.spec), ignoreIfNotExists = false, purge = false, retainData = false) - } - } - - test("drop partitions that do not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + assert(catalog.externalCatalog.listPartitions("db2", "tbl2").isEmpty) + // Drop multiple partitions at once + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) catalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), - Seq(part3.spec), + Seq(part1.spec, part2.spec), ignoreIfNotExists = false, purge = false, retainData = false) + assert(catalog.externalCatalog.listPartitions("db2", "tbl2").isEmpty) } - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part3.spec), - ignoreIfNotExists = true, - purge = false, - retainData = false) } - test("drop partitions with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithMoreColumns.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) + test("drop partitions when database/table does not exist") { + withBasicCatalog { catalog => + 
intercept[NoSuchDatabaseException] { + catalog.dropPartitions( + TableIdentifier("tbl1", Some("unknown_db")), + Seq(), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + intercept[NoSuchTableException] { + catalog.dropPartitions( + TableIdentifier("does_not_exist", Some("db2")), + Seq(), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } } - assert(e.getMessage.contains( - "Partition spec is invalid. The spec (a, b, c) must be contained within " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { + } + + test("drop partitions that do not exist") { + withBasicCatalog { catalog => + intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } catalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), - Seq(partWithUnknownColumns.spec), - ignoreIfNotExists = false, + Seq(part3.spec), + ignoreIfNotExists = true, purge = false, retainData = false) } - assert(e.getMessage.contains( - "Partition spec is invalid. The spec (a, unknown) must be contained within " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithEmptyValue.spec, part1.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) + } + + test("drop partitions with invalid partition spec") { + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithMoreColumns.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains( + "Partition spec is invalid. The spec (a, b, c) must be contained within " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithUnknownColumns.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains( + "Partition spec is invalid. The spec (a, unknown) must be contained within " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithEmptyValue.spec, part1.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("get partition") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), part1.spec).spec == part1.spec) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), part2.spec).spec == part2.spec) - // Get partition without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec == part1.spec) - assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec == part2.spec) - // Get non-existent partition - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), part3.spec) + withBasicCatalog { catalog => + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), part1.spec).spec == part1.spec) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), part2.spec).spec == part2.spec) + // Get partition without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec == part1.spec) + assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec == part2.spec) + // Get non-existent partition + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), part3.spec) + } } } test("get partition when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getPartition(TableIdentifier("tbl1", Some("unknown_db")), part1.spec) - } - intercept[NoSuchTableException] { - catalog.getPartition(TableIdentifier("does_not_exist", Some("db2")), part1.spec) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getPartition(TableIdentifier("tbl1", Some("unknown_db")), part1.spec) + } + intercept[NoSuchTableException] { + catalog.getPartition(TableIdentifier("does_not_exist", Some("db2")), part1.spec) + } } } test("get partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithLessColumns.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithMoreColumns.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithUnknownColumns.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithEmptyValue.spec) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithLessColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. 
The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithMoreColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithUnknownColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithEmptyValue.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("rename partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) - val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) - val newSpecs = Seq(newPart1.spec, newPart2.spec) - catalog.renamePartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), newSpecs) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), newPart1.spec).spec === newPart1.spec) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), newPart2.spec).spec === newPart2.spec) - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - } - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - } - // Rename partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - catalog.renamePartitions(TableIdentifier("tbl2"), newSpecs, Seq(part1.spec, part2.spec)) - assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec === part1.spec) - assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec === part2.spec) - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), newPart1.spec) - } - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), newPart2.spec) + withBasicCatalog { catalog => + val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) + val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) + val newSpecs = Seq(newPart1.spec, newPart2.spec) + catalog.renamePartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), newSpecs) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), newPart1.spec).spec === newPart1.spec) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), newPart2.spec).spec === newPart2.spec) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + } + // Rename partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.renamePartitions(TableIdentifier("tbl2"), newSpecs, Seq(part1.spec, part2.spec)) + assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec === part1.spec) + assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec 
=== part2.spec) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), newPart1.spec) + } + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), newPart2.spec) + } } } test("rename partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("unknown_db")), Seq(part1.spec), Seq(part2.spec)) - } - intercept[NoSuchTableException] { - catalog.renamePartitions( - TableIdentifier("does_not_exist", Some("db2")), Seq(part1.spec), Seq(part2.spec)) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("unknown_db")), Seq(part1.spec), Seq(part2.spec)) + } + intercept[NoSuchTableException] { + catalog.renamePartitions( + TableIdentifier("does_not_exist", Some("db2")), Seq(part1.spec), Seq(part2.spec)) + } } } test("rename partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithLessColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithUnknownColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithEmptyValue.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithLessColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("alter partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val newLocation = newUriForDatabase() - // Alter but keep spec the same - val oldPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - val oldPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq( - oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), - oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) - val newPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - val newPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - assert(newPart1.storage.locationUri == Some(newLocation)) - assert(newPart2.storage.locationUri == Some(newLocation)) - assert(oldPart1.storage.locationUri != Some(newLocation)) - assert(oldPart2.storage.locationUri != Some(newLocation)) - // Alter partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - catalog.alterPartitions(TableIdentifier("tbl2"), Seq(oldPart1, oldPart2)) - val newerPart1 = catalog.getPartition(TableIdentifier("tbl2"), part1.spec) - val newerPart2 = catalog.getPartition(TableIdentifier("tbl2"), part2.spec) - assert(oldPart1.storage.locationUri == newerPart1.storage.locationUri) - assert(oldPart2.storage.locationUri == newerPart2.storage.locationUri) - // Alter but change spec, should fail because new partition specs do not exist yet - val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) - val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) - intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq(badPart1, badPart2)) + withBasicCatalog { catalog => + val newLocation = newUriForDatabase() + // Alter but keep spec the same + val oldPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + val oldPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq( + oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), + oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) + val newPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + val newPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + assert(newPart1.storage.locationUri == Some(newLocation)) + assert(newPart2.storage.locationUri == Some(newLocation)) + assert(oldPart1.storage.locationUri != Some(newLocation)) + assert(oldPart2.storage.locationUri != Some(newLocation)) + // Alter partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.alterPartitions(TableIdentifier("tbl2"), Seq(oldPart1, oldPart2)) + val newerPart1 = catalog.getPartition(TableIdentifier("tbl2"), part1.spec) + val newerPart2 = catalog.getPartition(TableIdentifier("tbl2"), part2.spec) + assert(oldPart1.storage.locationUri == newerPart1.storage.locationUri) + assert(oldPart2.storage.locationUri == newerPart2.storage.locationUri) + // Alter but change spec, should fail because new partition specs do not exist yet + val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) + val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) + intercept[AnalysisException] { + 
catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq(badPart1, badPart2)) + } } } test("alter partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("unknown_db")), Seq(part1)) - } - intercept[NoSuchTableException] { - catalog.alterPartitions(TableIdentifier("does_not_exist", Some("db2")), Seq(part1)) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("unknown_db")), Seq(part1)) + } + intercept[NoSuchTableException] { + catalog.alterPartitions(TableIdentifier("does_not_exist", Some("db2")), Seq(part1)) + } } } test("alter partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithLessColumns)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithMoreColumns)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithUnknownColumns)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithEmptyValue)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithLessColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithMoreColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithUnknownColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithEmptyValue)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partition names") { - val catalog = new SessionCatalog(newBasicCatalog()) - val expectedPartitionNames = Seq("a=1/b=2", "a=3/b=4") - assert(catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2"))) == - expectedPartitionNames) - // List partition names without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalog.listPartitionNames(TableIdentifier("tbl2")) == expectedPartitionNames) + withBasicCatalog { catalog => + val expectedPartitionNames = Seq("a=1/b=2", "a=3/b=4") + assert(catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2"))) == + expectedPartitionNames) + // List partition names without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.listPartitionNames(TableIdentifier("tbl2")) == expectedPartitionNames) + } } test("list partition names with partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert( - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))) == - Seq("a=1/b=2")) + withBasicCatalog { catalog => + assert( + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))) == + Seq("a=1/b=2")) + } } test("list partition names with invalid partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithUnknownColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithEmptyValue.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalogPartitionsEqual( - catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) - // List partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) + // List partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) + } } test("list partitions with partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalogPartitionsEqual( - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))), part1)) + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))), part1)) + } } test("list partitions with invalid partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), - Some(partWithUnknownColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithEmptyValue.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), + Some(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.listPartitions(TableIdentifier("tbl1", Some("unknown_db"))) - } - intercept[NoSuchTableException] { - catalog.listPartitions(TableIdentifier("does_not_exist", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.listPartitions(TableIdentifier("tbl1", Some("unknown_db"))) + } + intercept[NoSuchTableException] { + catalog.listPartitions(TableIdentifier("does_not_exist", Some("db2"))) + } } } @@ -999,8 +1085,17 @@ class SessionCatalogSuite extends PlanTest { expectedParts: CatalogTablePartition*): Boolean = { // ExternalCatalog may set a default location for partitions, here we ignore the partition // location when comparing them. - actualParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet == - expectedParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet + // And for hive serde table, hive metastore will set some values(e.g.transient_lastDdlTime) + // in table's parameters and storage's properties, here we also ignore them. + val actualPartsNormalize = actualParts.map(p => + p.copy(parameters = Map.empty, storage = p.storage.copy( + properties = Map.empty, locationUri = None, serde = None))).toSet + + val expectedPartsNormalize = expectedParts.map(p => + p.copy(parameters = Map.empty, storage = p.storage.copy( + properties = Map.empty, locationUri = None, serde = None))).toSet + + actualPartsNormalize == expectedPartsNormalize } // -------------------------------------------------------------------------- @@ -1008,248 +1103,258 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("basic create and list functions") { - val externalCatalog = newEmptyCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) - assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) - // Create function without explicitly specifying database - sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) - assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) + // Create function without explicitly specifying database + catalog.setCurrentDatabase("mydb") + catalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) + } } test("create function when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.createFunction( - newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.createFunction( + newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) + } } } test("create function that already exists") { - val catalog = 
new SessionCatalog(newBasicCatalog()) - intercept[FunctionAlreadyExistsException] { - catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[FunctionAlreadyExistsException] { + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) + } + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) } - catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) } test("create temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempFunc1 = (e: Seq[Expression]) => e.head - val tempFunc2 = (e: Seq[Expression]) => e.last - val info1 = new ExpressionInfo("tempFunc1", "temp1") - val info2 = new ExpressionInfo("tempFunc2", "temp2") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) - val arguments = Seq(Literal(1), Literal(2), Literal(3)) - assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) - assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) - // Temporary function does not exist. - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) - } - val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) - val info3 = new ExpressionInfo("tempFunc3", "temp1") - // Temporary function already exists - intercept[TempFunctionAlreadyExistsException] { - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) - } - // Temporary function is overridden - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) - assert( - catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) + withBasicCatalog { catalog => + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + val info1 = new ExpressionInfo("tempFunc1", "temp1") + val info2 = new ExpressionInfo("tempFunc2", "temp2") + catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) + assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) + // Temporary function does not exist. 
+ intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) + } + val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) + val info3 = new ExpressionInfo("tempFunc3", "temp1") + // Temporary function already exists + intercept[TempFunctionAlreadyExistsException] { + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) + } + // Temporary function is overridden + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) + assert( + catalog.lookupFunction( + FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) + } } test("isTemporaryFunction") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - - // Returns false when the function does not exist - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("temp1"))) + withBasicCatalog { catalog => + // Returns false when the function does not exist + assert(!catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) - val tempFunc1 = (e: Seq[Expression]) => e.head - val info1 = new ExpressionInfo("tempFunc1", "temp1") - sessionCatalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + val tempFunc1 = (e: Seq[Expression]) => e.head + val info1 = new ExpressionInfo("tempFunc1", "temp1") + catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) - // Returns true when the function is temporary - assert(sessionCatalog.isTemporaryFunction(FunctionIdentifier("temp1"))) + // Returns true when the function is temporary + assert(catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) - // Returns false when the function is permanent - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("func1", Some("db2")))) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("db2.func1"))) - sessionCatalog.setCurrentDatabase("db2") - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("func1"))) + // Returns false when the function is permanent + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1", Some("db2")))) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("db2.func1"))) + catalog.setCurrentDatabase("db2") + assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1"))) - // Returns false when the function is built-in or hive - assert(FunctionRegistry.builtin.functionExists("sum")) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("sum"))) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))) + // Returns false when the function is built-in or hive + assert(FunctionRegistry.builtin.functionExists("sum")) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("sum"))) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))) + } } test("drop function") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - sessionCatalog.dropFunction( - FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) - assert(externalCatalog.listFunctions("db2", "*").isEmpty) - // Drop function without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) - 
assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) - sessionCatalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) - assert(externalCatalog.listFunctions("db2", "*").isEmpty) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) + catalog.dropFunction( + FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").isEmpty) + // Drop function without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) + catalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").isEmpty) + } } test("drop function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.dropFunction( - FunctionIdentifier("something", Some("unknown_db")), ignoreIfNotExists = false) - } - intercept[NoSuchFunctionException] { - catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.dropFunction( + FunctionIdentifier("something", Some("unknown_db")), ignoreIfNotExists = false) + } + intercept[NoSuchFunctionException] { + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) + } + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) } - catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) } test("drop temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info = new ExpressionInfo("tempFunc", "func1") - val tempFunc = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) - val arguments = Seq(Literal(1), Literal(2), Literal(3)) - assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) - catalog.dropTempFunction("func1", ignoreIfNotExists = false) - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("func1"), arguments) - } - intercept[NoSuchTempFunctionException] { + withBasicCatalog { catalog => + val info = new ExpressionInfo("tempFunc", "func1") + val tempFunc = (e: Seq[Expression]) => e.head + catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("func1"), arguments) + } + intercept[NoSuchTempFunctionException] { + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + } + catalog.dropTempFunction("func1", ignoreIfNotExists = true) } - catalog.dropTempFunction("func1", ignoreIfNotExists = true) } test("get function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val expected = - CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, - Seq.empty[FunctionResource]) - assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) - // Get function without explicitly specifying database - catalog.setCurrentDatabase("db2") - 
assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) + withBasicCatalog { catalog => + val expected = + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[FunctionResource]) + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) + // Get function without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) + } } test("get function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("unknown_db"))) - } - intercept[NoSuchFunctionException] { - catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("unknown_db"))) + } + intercept[NoSuchFunctionException] { + catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) + } } } test("lookup temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info1 = new ExpressionInfo("tempFunc1", "func1") - val tempFunc1 = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - assert(catalog.lookupFunction( - FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) - catalog.dropTempFunction("func1", ignoreIfNotExists = false) - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) + withBasicCatalog { catalog => + val info1 = new ExpressionInfo("tempFunc1", "func1") + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + assert(catalog.lookupFunction( + FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) + } } } test("list functions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info1 = new ExpressionInfo("tempFunc1", "func1") - val info2 = new ExpressionInfo("tempFunc2", "yes_me") - val tempFunc1 = (e: Seq[Expression]) => e.head - val tempFunc2 = (e: Seq[Expression]) => e.last - catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) - catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) - assert(catalog.listFunctions("db1", "*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("yes_me"))) - assert(catalog.listFunctions("db2", "*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("yes_me"), - FunctionIdentifier("func1", Some("db2")), - FunctionIdentifier("func2", Some("db2")), - FunctionIdentifier("not_me", Some("db2")))) - assert(catalog.listFunctions("db2", "func*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("func1", Some("db2")), - FunctionIdentifier("func2", Some("db2")))) + withBasicCatalog { catalog => + val info1 = new ExpressionInfo("tempFunc1", "func1") + val 
info2 = new ExpressionInfo("tempFunc2", "yes_me") + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) + assert(catalog.listFunctions("db1", "*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("yes_me"))) + assert(catalog.listFunctions("db2", "*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("yes_me"), + FunctionIdentifier("func1", Some("db2")), + FunctionIdentifier("func2", Some("db2")), + FunctionIdentifier("not_me", Some("db2")))) + assert(catalog.listFunctions("db2", "func*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("func1", Some("db2")), + FunctionIdentifier("func2", Some("db2")))) + } } test("list functions when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.listFunctions("unknown_db", "func*") + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.listFunctions("unknown_db", "func*") + } } } test("clone SessionCatalog - temp views") { - val externalCatalog = newEmptyCatalog() - val original = new SessionCatalog(externalCatalog) - val tempTable1 = Range(1, 10, 1, 10) - original.createTempView("copytest1", tempTable1, overrideIfExists = false) + withEmptyCatalog { original => + val tempTable1 = Range(1, 10, 1, 10) + original.createTempView("copytest1", tempTable1, overrideIfExists = false) - // check if tables copied over - val clone = original.newSessionCatalogWith( - SimpleCatalystConf(caseSensitiveAnalysis = true), - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) - assert(original ne clone) - assert(clone.getTempView("copytest1") == Some(tempTable1)) + // check if tables copied over + val clone = original.newSessionCatalogWith( + SimpleCatalystConf(caseSensitiveAnalysis = true), + new Configuration(), + new SimpleFunctionRegistry, + CatalystSqlParser) + assert(original ne clone) + assert(clone.getTempView("copytest1") == Some(tempTable1)) - // check if clone and original independent - clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false) - assert(original.getTempView("copytest1") == Some(tempTable1)) + // check if clone and original independent + clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false) + assert(original.getTempView("copytest1") == Some(tempTable1)) - val tempTable2 = Range(1, 20, 2, 10) - original.createTempView("copytest2", tempTable2, overrideIfExists = false) - assert(clone.getTempView("copytest2").isEmpty) + val tempTable2 = Range(1, 20, 2, 10) + original.createTempView("copytest2", tempTable2, overrideIfExists = false) + assert(clone.getTempView("copytest2").isEmpty) + } } test("clone SessionCatalog - current db") { - val externalCatalog = newEmptyCatalog() - val db1 = "db1" - val db2 = "db2" - val db3 = "db3" - - externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true) - externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true) - externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true) - - val original = new SessionCatalog(externalCatalog) - original.setCurrentDatabase(db1) - - // check 
if current db copied over - val clone = original.newSessionCatalogWith( - SimpleCatalystConf(caseSensitiveAnalysis = true), - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) - assert(original ne clone) - assert(clone.getCurrentDatabase == db1) - - // check if clone and original independent - clone.setCurrentDatabase(db2) - assert(original.getCurrentDatabase == db1) - original.setCurrentDatabase(db3) - assert(clone.getCurrentDatabase == db2) + withEmptyCatalog { original => + val db1 = "db1" + val db2 = "db2" + val db3 = "db3" + + original.externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true) + original.externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true) + original.externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true) + + original.setCurrentDatabase(db1) + + // check if current db copied over + val clone = original.newSessionCatalogWith( + SimpleCatalystConf(caseSensitiveAnalysis = true), + new Configuration(), + new SimpleFunctionRegistry, + CatalystSqlParser) + assert(original ne clone) + assert(clone.getCurrentDatabase == db1) + + // check if clone and original independent + clone.setCurrentDatabase(db2) + assert(original.getCurrentDatabase == db1) + original.setCurrentDatabase(db3) + assert(clone.getCurrentDatabase == db2) + } } test("SPARK-19737: detect undefined functions without triggering relation resolution") { @@ -1258,18 +1363,22 @@ class SessionCatalogSuite extends PlanTest { Seq(true, false) foreach { caseSensitive => val conf = SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf) - val analyzer = new Analyzer(catalog, conf) - - // The analyzer should report the undefined function rather than the undefined table first. - val cause = intercept[AnalysisException] { - analyzer.execute( - UnresolvedRelation(TableIdentifier("undefined_table")).select( - UnresolvedFunction("undefined_fn", Nil, isDistinct = false) + try { + val analyzer = new Analyzer(catalog, conf) + + // The analyzer should report the undefined function rather than the undefined table first. + val cause = intercept[AnalysisException] { + analyzer.execute( + UnresolvedRelation(TableIdentifier("undefined_table")).select( + UnresolvedFunction("undefined_fn", Nil, isDistinct = false) + ) ) - ) - } + } - assert(cause.getMessage.contains("Undefined function: 'undefined_fn'")) + assert(cause.getMessage.contains("Undefined function: 'undefined_fn'")) + } finally { + catalog.reset() + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala new file mode 100644 index 000000000000..285f35b0b0ea --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.catalog.{CatalogTestUtils, ExternalCatalog, SessionCatalogSuite} +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveSingleton { + + protected override val isHiveExternalCatalog = true + + private val externalCatalog = { + val catalog = spark.sharedState.externalCatalog + catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog + } + + protected val utils = new CatalogTestUtils { + override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat" + override val tableOutputFormat: String = + "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat" + override val defaultProvider: String = "hive" + override def newEmptyCatalog(): ExternalCatalog = externalCatalog + } +} From 2ea214dd05da929840c15891e908384cfa695ca8 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Thu, 16 Mar 2017 13:05:36 -0700 Subject: [PATCH 053/512] [SPARK-19721][SS] Good error message for version mismatch in log files ## Problem There are several places where we write out version identifiers in various logs for structured streaming (usually `v1`). However, in the places where we check for this, we throw a confusing error message. ## What changes were proposed in this pull request? This patch made two major changes: 1. added a `parseVersion(...)` method, and based on this method, fixed the following places the way they did version checking (no other place needed to do this checking): ``` HDFSMetadataLog - CompactibleFileStreamLog ------------> fixed with this patch - FileStreamSourceLog ---------------> inherited the fix of `CompactibleFileStreamLog` - FileStreamSinkLog -----------------> inherited the fix of `CompactibleFileStreamLog` - OffsetSeqLog ------------------------> fixed with this patch - anonymous subclass in KafkaSource ---> fixed with this patch ``` 2. changed the type of `FileStreamSinkLog.VERSION`, `FileStreamSourceLog.VERSION` etc. from `String` to `Int`, so that we can identify newer versions via `version > 1` instead of `version != "v1"` - note this didn't break any backwards compatibility -- we are still writing out `"v1"` and reading back `"v1"` ## Exception message with this patch ``` java.lang.IllegalStateException: Failed to read log file /private/var/folders/nn/82rmvkk568sd8p3p8tb33trw0000gn/T/spark-86867b65-0069-4ef1-b0eb-d8bd258ff5b8/0. UnsupportedLogVersion: maximum supported log version is v1, but encountered v99. The log file was produced by a newer version of Spark and cannot be read by this version. Please upgrade. 
at org.apache.spark.sql.execution.streaming.HDFSMetadataLog.get(HDFSMetadataLog.scala:202) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3$$anonfun$apply$mcV$sp$2.apply(OffsetSeqLogSuite.scala:78) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3$$anonfun$apply$mcV$sp$2.apply(OffsetSeqLogSuite.scala:75) at org.apache.spark.sql.test.SQLTestUtils$class.withTempDir(SQLTestUtils.scala:133) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite.withTempDir(OffsetSeqLogSuite.scala:26) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3.apply$mcV$sp(OffsetSeqLogSuite.scala:75) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3.apply(OffsetSeqLogSuite.scala:75) at org.apache.spark.sql.execution.streaming.OffsetSeqLogSuite$$anonfun$3.apply(OffsetSeqLogSuite.scala:75) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) ``` ## How was this patch tested? unit tests Author: Liwei Lin Closes #17070 from lw-lin/better-msg. --- .../spark/sql/kafka010/KafkaSource.scala | 14 +++---- .../spark/sql/kafka010/KafkaSourceSuite.scala | 9 ++++- .../streaming/CompactibleFileStreamLog.scala | 9 ++--- .../streaming/FileStreamSinkLog.scala | 4 +- .../streaming/FileStreamSourceLog.scala | 4 +- .../execution/streaming/HDFSMetadataLog.scala | 36 +++++++++++++++++ .../execution/streaming/OffsetSeqLog.scala | 10 ++--- .../CompactibleFileStreamLogSuite.scala | 40 ++++++++++++++++--- .../streaming/FileStreamSinkLogSuite.scala | 8 ++-- .../streaming/HDFSMetadataLogSuite.scala | 27 +++++++++++++ .../streaming/OffsetSeqLogSuite.scala | 17 ++++++++ 11 files changed, 143 insertions(+), 35 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 92b5d91ba435..1fb0a338299b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -100,7 +100,7 @@ private[kafka010] class KafkaSource( override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517) val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write(VERSION) + writer.write("v" + VERSION + "\n") writer.write(metadata.json) writer.flush } @@ -111,13 +111,13 @@ private[kafka010] class KafkaSource( // HDFSMetadataLog guarantees that it never creates a partial file. assert(content.length != 0) if (content(0) == 'v') { - if (content.startsWith(VERSION)) { - KafkaSourceOffset(SerializedOffset(content.substring(VERSION.length))) + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) } else { - val versionInFile = content.substring(0, content.indexOf("\n")) throw new IllegalStateException( - s"Unsupported format. Expected version is ${VERSION.stripLineEnd} " + - s"but was $versionInFile. 
Please upgrade your Spark.") + s"Log file was malformed: failed to detect the log file version line.") } } else { // The log was generated by Spark 2.1.0 @@ -351,7 +351,7 @@ private[kafka010] object KafkaSource { | source option "failOnDataLoss" to "false". """.stripMargin - private val VERSION = "v1\n" + private[kafka010] val VERSION = 1 def getSortedExecutorList(sc: SparkContext): Array[String] = { val bm = sc.env.blockManager diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index bf6aad671a18..7b6396e0291c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -205,7 +205,7 @@ class KafkaSourceSuite extends KafkaSourceTest { override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { out.write(0) val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8)) - writer.write(s"v0\n${metadata.json}") + writer.write(s"v99999\n${metadata.json}") writer.flush } } @@ -227,7 +227,12 @@ class KafkaSourceSuite extends KafkaSourceTest { source.getOffset.get // Read initial offset } - assert(e.getMessage.contains("Please upgrade your Spark")) + Seq( + s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 5a6f9e87f6ea..408c8f81f17b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.SparkSession * doing a compaction, it will read all old log files and merge them with the new batch. 
*/ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends HDFSMetadataLog[Array[T]](sparkSession, path) { @@ -134,7 +134,7 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( override def serialize(logData: Array[T], out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(metadataLogVersion.getBytes(UTF_8)) + out.write(("v" + metadataLogVersion).getBytes(UTF_8)) logData.foreach { data => out.write('\n') out.write(Serialization.write(data).getBytes(UTF_8)) @@ -146,10 +146,7 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - val version = lines.next() - if (version != metadataLogVersion) { - throw new IllegalStateException(s"Unknown log version: ${version}") - } + val version = parseVersion(lines.next(), metadataLogVersion) lines.map(Serialization.read[T]).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index eb6eed87eca7..8d718b2164d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -77,7 +77,7 @@ object SinkFileStatus { * (drops the deleted files). */ class FileStreamSinkLog( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) { @@ -106,7 +106,7 @@ class FileStreamSinkLog( } object FileStreamSinkLog { - val VERSION = "v1" + val VERSION = 1 val DELETE_ACTION = "delete" val ADD_ACTION = "add" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index 81908c0cefdf..33e6a1d5d6e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming.FileStreamSource.FileEntry import org.apache.spark.sql.internal.SQLConf class FileStreamSourceLog( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[FileEntry](metadataLogVersion, sparkSession, path) { @@ -120,5 +120,5 @@ class FileStreamSourceLog( } object FileStreamSourceLog { - val VERSION = "v1" + val VERSION = 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index f9e1f7de9ec0..60ce64261c4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -195,6 +195,11 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: val input = fileManager.open(batchMetadataFile) try { Some(deserialize(input)) + } catch { + case ise: IllegalStateException => + // re-throw the exception with the log file path added 
+ throw new IllegalStateException( + s"Failed to read log file $batchMetadataFile. ${ise.getMessage}", ise) } finally { IOUtils.closeQuietly(input) } @@ -268,6 +273,37 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: new FileSystemManager(metadataPath, hadoopConf) } } + + /** + * Parse the log version from the given `text` -- will throw exception when the parsed version + * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", + * "v123xyz" etc.) + */ + private[sql] def parseVersion(text: String, maxSupportedVersion: Int): Int = { + if (text.length > 0 && text(0) == 'v') { + val version = + try { + text.substring(1, text.length).toInt + } catch { + case _: NumberFormatException => + throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + + s"version from $text.") + } + if (version > 0) { + if (version > maxSupportedVersion) { + throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " + + s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " + + s"by a newer version of Spark and cannot be read by this version. Please upgrade.") + } else { + return version + } + } + } + + // reaching here means we failed to read the correct log version + throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + + s"version from $text.") + } } object HDFSMetadataLog { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index 3210d8ad64e2..4f8cd116f610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -55,10 +55,8 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - val version = lines.next() - if (version != OffsetSeqLog.VERSION) { - throw new IllegalStateException(s"Unknown log version: ${version}") - } + + val version = parseVersion(lines.next(), OffsetSeqLog.VERSION) // read metadata val metadata = lines.next().trim match { @@ -70,7 +68,7 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(OffsetSeqLog.VERSION.getBytes(UTF_8)) + out.write(("v" + OffsetSeqLog.VERSION).getBytes(UTF_8)) // write metadata out.write('\n') @@ -88,6 +86,6 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) } object OffsetSeqLog { - private val VERSION = "v1" + private[streaming] val VERSION = 1 private val SERIALIZED_VOID_OFFSET = "-" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 24d92a96237e..20ac06f048c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -122,7 +122,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext defaultMinBatchesToRetain = 1, compactibleLog => { val logs = Array("entry_1", "entry_2", 
"entry_3") - val expected = s"""${FakeCompactibleFileStreamLog.VERSION} + val expected = s"""v${FakeCompactibleFileStreamLog.VERSION} |"entry_1" |"entry_2" |"entry_3"""".stripMargin @@ -132,7 +132,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext baos.reset() compactibleLog.serialize(Array(), baos) - assert(FakeCompactibleFileStreamLog.VERSION === baos.toString(UTF_8.name())) + assert(s"v${FakeCompactibleFileStreamLog.VERSION}" === baos.toString(UTF_8.name())) }) } @@ -142,7 +142,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext defaultCompactInterval = 3, defaultMinBatchesToRetain = 1, compactibleLog => { - val logs = s"""${FakeCompactibleFileStreamLog.VERSION} + val logs = s"""v${FakeCompactibleFileStreamLog.VERSION} |"entry_1" |"entry_2" |"entry_3"""".stripMargin @@ -152,10 +152,36 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext assert(Nil === compactibleLog.deserialize( - new ByteArrayInputStream(FakeCompactibleFileStreamLog.VERSION.getBytes(UTF_8)))) + new ByteArrayInputStream(s"v${FakeCompactibleFileStreamLog.VERSION}".getBytes(UTF_8)))) }) } + test("deserialization log written by future version") { + withTempDir { dir => + def newFakeCompactibleFileStreamLog(version: Int): FakeCompactibleFileStreamLog = + new FakeCompactibleFileStreamLog( + version, + _fileCleanupDelayMs = Long.MaxValue, // this param does not matter here in this test case + _defaultCompactInterval = 3, // this param does not matter here in this test case + _defaultMinBatchesToRetain = 1, // this param does not matter here in this test case + spark, + dir.getCanonicalPath) + + val writer = newFakeCompactibleFileStreamLog(version = 2) + val reader = newFakeCompactibleFileStreamLog(version = 1) + writer.add(0, Array("entry")) + val e = intercept[IllegalStateException] { + reader.get(0) + } + Seq( + "maximum supported log version is v1, but encountered v2", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("compact") { withFakeCompactibleFileStreamLog( fileCleanupDelayMs = Long.MaxValue, @@ -219,6 +245,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext ): Unit = { withTempDir { file => val compactibleLog = new FakeCompactibleFileStreamLog( + FakeCompactibleFileStreamLog.VERSION, fileCleanupDelayMs, defaultCompactInterval, defaultMinBatchesToRetain, @@ -230,17 +257,18 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext } object FakeCompactibleFileStreamLog { - val VERSION = "test_version" + val VERSION = 1 } class FakeCompactibleFileStreamLog( + metadataLogVersion: Int, _fileCleanupDelayMs: Long, _defaultCompactInterval: Int, _defaultMinBatchesToRetain: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[String]( - FakeCompactibleFileStreamLog.VERSION, + metadataLogVersion, sparkSession, path ) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 340d2945acd4..dd3a414659c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -74,7 +74,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with 
SharedSQLContext { action = FileStreamSinkLog.ADD_ACTION)) // scalastyle:off - val expected = s"""$VERSION + val expected = s"""v$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin @@ -84,14 +84,14 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { assert(expected === baos.toString(UTF_8.name())) baos.reset() sinkLog.serialize(Array(), baos) - assert(VERSION === baos.toString(UTF_8.name())) + assert(s"v$VERSION" === baos.toString(UTF_8.name())) } } test("deserialize") { withFileStreamSinkLog { sinkLog => // scalastyle:off - val logs = s"""$VERSION + val logs = s"""v$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin @@ -125,7 +125,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { assert(expected === sinkLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8)))) - assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(VERSION.getBytes(UTF_8)))) + assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(s"v$VERSION".getBytes(UTF_8)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 55750b920298..662c4466b21b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -127,6 +127,33 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } + test("HDFSMetadataLog: parseVersion") { + withTempDir { dir => + val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath) + def assertLogFileMalformed(func: => Int): Unit = { + val e = intercept[IllegalStateException] { func } + assert(e.getMessage.contains(s"Log file was malformed: failed to read correct log version")) + } + assertLogFileMalformed { metadataLog.parseVersion("", 100) } + assertLogFileMalformed { metadataLog.parseVersion("xyz", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v10.x", 100) } + assertLogFileMalformed { metadataLog.parseVersion("10", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v0", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v-10", 100) } + + assert(metadataLog.parseVersion("v10", 10) === 10) + assert(metadataLog.parseVersion("v10", 100) === 10) + + val e = intercept[IllegalStateException] { metadataLog.parseVersion("v200", 100) } + Seq( + "maximum supported log version is v100, but encountered v200", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("HDFSMetadataLog: restart") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index 5ae8b2484d2e..f7f0dade8717 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming import java.io.File import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.stringToFile import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { @@ -70,6 +71,22 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { } } + test("deserialization log written by future version") { + withTempDir { dir => + stringToFile(new File(dir, "0"), "v99999") + val log = new OffsetSeqLog(spark, dir.getCanonicalPath) + val e = intercept[IllegalStateException] { + log.get(0) + } + Seq( + s"maximum supported log version is v${OffsetSeqLog.VERSION}, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("read Spark 2.1.0 log format") { val (batchId, offsetSeq) = readFromResource("offset-log-version-2.1.0") assert(batchId === 0) From 4c3200546c5c55e671988a957011417ba76a0600 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 16 Mar 2017 17:10:15 -0700 Subject: [PATCH 054/512] [SPARK-19635][ML] DataFrame-based API for chi square test ## What changes were proposed in this pull request? Wrapper taking and return a DataFrame ## How was this patch tested? Copied unit tests from RDD-based API Author: Joseph K. Bradley Closes #17110 from jkbradley/df-hypotests. --- .../org/apache/spark/ml/stat/ChiSquare.scala | 81 +++++++++++++++ .../spark/mllib/stat/test/ChiSqTest.scala | 8 +- .../apache/spark/ml/stat/ChiSquareSuite.scala | 98 +++++++++++++++++++ .../mllib/stat/HypothesisTestSuite.scala | 11 ++- 4 files changed, 192 insertions(+), 6 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala new file mode 100644 index 000000000000..c3865ce6a9e2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.stat + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col + + +/** + * :: Experimental :: + * + * Chi-square hypothesis testing for categorical data. + * + * See Wikipedia for more information + * on the Chi-squared test. + */ +@Experimental +@Since("2.2.0") +object ChiSquare { + + /** Used to construct output schema of tests */ + private case class ChiSquareResult( + pValues: Vector, + degreesOfFreedom: Array[Int], + statistics: Vector) + + /** + * Conduct Pearson's independence test for every feature against the label across the input RDD. + * For each feature, the (feature, label) pairs are converted into a contingency matrix for which + * the Chi-squared statistic is computed. All label and feature values must be categorical. + * + * The null hypothesis is that the occurrence of the outcomes is statistically independent. + * + * @param dataset DataFrame of categorical labels and categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * @param featuresCol Name of features column in dataset, of type `Vector` (`VectorUDT`) + * @param labelCol Name of label column in dataset, of any numerical type + * @return DataFrame containing the test result for every feature against the label. + * This DataFrame will contain a single Row with the following fields: + * - `pValues: Vector` + * - `degreesOfFreedom: Array[Int]` + * - `statistics: Vector` + * Each of these fields has one value per feature. + */ + @Since("2.2.0") + def test(dataset: DataFrame, featuresCol: String, labelCol: String): DataFrame = { + val spark = dataset.sparkSession + import spark.implicits._ + + SchemaUtils.checkColumnType(dataset.schema, featuresCol, new VectorUDT) + SchemaUtils.checkNumericType(dataset.schema, labelCol) + val rdd = dataset.select(col(labelCol).cast("double"), col(featuresCol)).as[(Double, Vector)] + .rdd.map { case (label, features) => OldLabeledPoint(label, OldVectors.fromML(features)) } + val testResults = OldStatistics.chiSqTest(rdd) + val pValues: Vector = Vectors.dense(testResults.map(_.pValue)) + val degreesOfFreedom: Array[Int] = testResults.map(_.degreesOfFreedom) + val statistics: Vector = Vectors.dense(testResults.map(_.statistic)) + spark.createDataFrame(Seq(ChiSquareResult(pValues, degreesOfFreedom, statistics))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 9a63b8a5d63d..ee51248e5355 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -41,7 +41,7 @@ import org.apache.spark.rdd.RDD * * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test */ -private[stat] object ChiSqTest extends Logging { +private[spark] object ChiSqTest extends Logging { /** * @param name String name for the method. 
@@ -70,6 +70,11 @@ private[stat] object ChiSqTest extends Logging { } } + /** + * Max number of categories when indexing labels and features + */ + private[spark] val maxCategories: Int = 10000 + /** * Conduct Pearson's independence test for each feature against the label across the input RDD. * The contingency table is constructed from the raw (feature, label) pairs and used to conduct @@ -78,7 +83,6 @@ private[stat] object ChiSqTest extends Logging { */ def chiSquaredFeatures(data: RDD[LabeledPoint], methodName: String = PEARSON.name): Array[ChiSqTestResult] = { - val maxCategories = 10000 val numCols = data.first().features.size val results = new Array[ChiSqTestResult](numCols) var labels: Map[Double, Int] = null diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala new file mode 100644 index 000000000000..b4bed82e4d00 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.stat + +import java.util.Random + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.stat.test.ChiSqTest +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class ChiSquareSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + test("test DataFrame of labeled points") { + // labels: 1.0 (2 / 6), 0.0 (4 / 6) + // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) + // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) + val data = Seq( + LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + for (numParts <- List(2, 4, 6, 8)) { + val df = spark.createDataFrame(sc.parallelize(data, numParts)) + val chi = ChiSquare.test(df, "features", "label") + val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = + chi.select("pValues", "degreesOfFreedom", "statistics") + .as[(Vector, Array[Int], Vector)].head() + assert(pValues ~== Vectors.dense(0.6873, 0.6823) relTol 1e-4) + assert(degreesOfFreedom === Array(2, 3)) + assert(statistics ~== Vectors.dense(0.75, 1.5) relTol 1e-4) + } + } + + test("large number of features (SPARK-3087)") { + // Test that the right number of results is returned + val numCols = 1001 + val sparseData = Array( + LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) + val df = spark.createDataFrame(sparseData) + val chi = ChiSquare.test(df, "features", "label") + val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = + chi.select("pValues", "degreesOfFreedom", "statistics") + .as[(Vector, Array[Int], Vector)].head() + assert(pValues.size === numCols) + assert(degreesOfFreedom.length === numCols) + assert(statistics.size === numCols) + assert(pValues(1000) !== null) // SPARK-3087 + } + + test("fail on continuous features or labels") { + val tooManyCategories: Int = 100000 + assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " + + "tooManyCategories be large enough to cause ChiSqTest to throw an exception.") + + val random = new Random(11L) + val continuousLabel = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + withClue("ChiSquare should throw an exception when given a continuous-valued label") { + intercept[SparkException] { + val df = spark.createDataFrame(continuousLabel) + ChiSquare.test(df, "features", "label") + } + } + val continuousFeature = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + withClue("ChiSquare should throw an exception when given continuous-valued features") { + intercept[SparkException] { + val df = spark.createDataFrame(continuousFeature) + ChiSquare.test(df, "features", "label") + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 46fcebe13274..992b87656189 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -145,14 +145,17 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(chi(1000) != null) // SPARK-3087 // Detect continuous features or labels + val tooManyCategories: Int = 100000 + assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " + + "tooManyCategories be large enough to cause ChiSqTest to throw an exception.") val random = new Random(11L) - val continuousLabel = - Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + val continuousLabel = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) intercept[SparkException] { Statistics.chiSqTest(sc.parallelize(continuousLabel, 2)) } - val continuousFeature = - Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + val continuousFeature = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) intercept[SparkException] { Statistics.chiSqTest(sc.parallelize(continuousFeature, 2)) } From 8537c00e0a17eff2a8c6745fbdd1d08873c0434d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 16 Mar 2017 18:31:57 -0700 Subject: [PATCH 055/512] [SPARK-19987][SQL] Pass all filters into FileIndex ## What changes were proposed in this pull request? This is a tiny teeny refactoring to pass data filters also to the FileIndex, so FileIndex can have a more global view on predicates. ## How was this patch tested? Change should be covered by existing test cases. Author: Reynold Xin Closes #17322 from rxin/SPARK-19987. --- .../sql/execution/DataSourceScanExec.scala | 23 +++++++++++-------- .../execution/OptimizeMetadataOnlyQuery.scala | 2 +- .../datasources/CatalogFileIndex.scala | 5 ++-- .../sql/execution/datasources/FileIndex.scala | 15 ++++++++---- .../datasources/FileSourceStrategy.scala | 5 +--- .../PartitioningAwareFileIndex.scala | 8 ++++--- .../spark/sql/hive/HiveMetastoreCatalog.scala | 4 +--- 7 files changed, 35 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 8ebad676ca31..bfe9c8e351ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -23,18 +23,18 @@ import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import 
org.apache.spark.sql.sources.{BaseRelation, Filter} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils trait DataSourceScanExec extends LeafExecNode with CodegenSupport { @@ -135,7 +135,7 @@ case class RowDataSourceScanExec( * @param output Output attributes of the scan. * @param outputSchema Output schema of the scan. * @param partitionFilters Predicates to use for partition pruning. - * @param dataFilters Data source filters to use for filtering data within partitions. + * @param dataFilters Filters on non-partition columns. * @param metastoreTableIdentifier identifier for the table in the metastore. */ case class FileSourceScanExec( @@ -143,7 +143,7 @@ case class FileSourceScanExec( output: Seq[Attribute], outputSchema: StructType, partitionFilters: Seq[Expression], - dataFilters: Seq[Filter], + dataFilters: Seq[Expression], override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { @@ -156,7 +156,8 @@ case class FileSourceScanExec( false } - @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters) + @transient private lazy val selectedPartitions = + relation.location.listFiles(partitionFilters, dataFilters) override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { @@ -225,6 +226,10 @@ case class FileSourceScanExec( } } + @transient + private val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) + logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") + // These metadata values make scan plans uniquely identifiable for equality checking. 
override val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") @@ -237,7 +242,7 @@ case class FileSourceScanExec( "ReadSchema" -> outputSchema.catalogString, "Batched" -> supportsBatch.toString, "PartitionFilters" -> seqToString(partitionFilters), - "PushedFilters" -> seqToString(dataFilters), + "PushedFilters" -> seqToString(pushedDownFilters), "Location" -> locationDesc) val withOptPartitionCount = relation.partitionSchemaOption.map { _ => @@ -255,7 +260,7 @@ case class FileSourceScanExec( dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, requiredSchema = outputSchema, - filters = dataFilters, + filters = pushedDownFilters, options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 769deb1890b6..3c046ce49428 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -98,7 +98,7 @@ case class OptimizeMetadataOnlyQuery( relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(filters = Nil) + val partitionData = fsRelation.location.listFiles(Nil, Nil) LocalRelation(partAttrs, partitionData.map(_.values)) case relation: CatalogRelation => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index d6c4b97ebd08..db0254f8d558 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -54,8 +54,9 @@ class CatalogFileIndex( override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq - override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { - filterPartitions(filters).listFiles(Nil) + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { + filterPartitions(partitionFilters).listFiles(Nil, dataFilters) } override def refresh(): Unit = fileStatusCache.invalidateAll() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 277223d52ec5..6b99d38fe572 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -46,12 +46,17 @@ trait FileIndex { * Returns all valid files grouped into partitions when the data is partitioned. If the data is * unpartitioned, this will return a single partition with no partition values. * - * @param filters The filters used to prune which partitions are returned. These filters must - * only refer to partition columns and this method will only return files - * where these predicates are guaranteed to evaluate to `true`. Thus, these - * filters will not need to be evaluated again on the returned data. 
+ * @param partitionFilters The filters used to prune which partitions are returned. These filters + * must only refer to partition columns and this method will only return + * files where these predicates are guaranteed to evaluate to `true`. + * Thus, these filters will not need to be evaluated again on the + * returned data. + * @param dataFilters Filters that can be applied on non-partitioned columns. The implementation + * does not need to guarantee these filters are applied, i.e. the execution + * engine will ensure these filters are still applied on the returned files. */ - def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] + def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] /** * Returns the list of files that will be read when scanning this relation. This call may be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 26e1380eca49..17f7e0e601c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -100,9 +100,6 @@ object FileSourceStrategy extends Strategy with Logging { val outputSchema = readDataColumns.toStructType logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}") - val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) - logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") - val outputAttributes = readDataColumns ++ partitionColumns val scan = @@ -111,7 +108,7 @@ object FileSourceStrategy extends Strategy with Logging { outputAttributes, outputSchema, partitionKeyFilters.toSeq, - pushedDownFilters, + dataFilters, table.map(_.identifier)) val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index db8bbc52aaf4..71500a010581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -54,17 +54,19 @@ abstract class PartitioningAwareFileIndex( override def partitionSchema: StructType = partitionSpec().partitionColumns - protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) + protected val hadoopConf: Configuration = + sparkSession.sessionState.newHadoopConfWithOptions(parameters) protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] - override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { PartitionDirectory(InternalRow.empty, allFiles().filter(f => isDataPath(f.getPath))) :: Nil } else { - prunePartitions(filters, partitionSpec()).map { + prunePartitions(partitionFilters, partitionSpec()).map { case PartitionPath(values, path) => val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9f0d1ceb28fc..2e060ab9f680 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.hive -import java.net.URI - import scala.util.control.NonFatal import com.google.common.util.concurrent.Striped @@ -248,7 +246,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log .inferSchema( sparkSession, options, - fileIndex.listFiles(Nil).flatMap(_.files)) + fileIndex.listFiles(Nil, Nil).flatMap(_.files)) .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _)) inferredSchema match { From 13538cf3dd089222c7e12a3cd6e72ac836fa51ac Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 17 Mar 2017 16:43:42 +0800 Subject: [PATCH 056/512] [SPARK-19882][SQL] Pivot with null as a distinct pivot value throws NPE ## What changes were proposed in this pull request? Allows null values of the pivot column to be included in the pivot values list without throwing NPE Note this PR was made as an alternative to #17224 but preserves the two phase aggregate operation that is needed for good performance. ## How was this patch tested? Additional unit test Author: Andrew Ray Closes #17226 from aray/pivot-null. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../expressions/aggregate/PivotFirst.scala | 18 +++++++++--------- .../apache/spark/sql/DataFramePivotSuite.scala | 14 ++++++++++++++ 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 68a4746a54d9..8cf407382619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -524,7 +524,7 @@ class Analyzer( } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(expr: Expression) = { - If(EqualTo(pivotColumn, value), expr, Literal(null)) + If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 9ad31243e412..523714869242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -91,14 +91,12 @@ case class PivotFirst( override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit = { val pivotColValue = pivotColumn.eval(inputRow) - if (pivotColValue != null) { - // We ignore rows whose pivot column value is not in the list of pivot column values. - val index = pivotIndex.getOrElse(pivotColValue, -1) - if (index >= 0) { - val value = valueColumn.eval(inputRow) - if (value != null) { - updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) - } + // We ignore rows whose pivot column value is not in the list of pivot column values. 
+ val index = pivotIndex.getOrElse(pivotColValue, -1) + if (index >= 0) { + val value = valueColumn.eval(inputRow) + if (value != null) { + updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) } } } @@ -140,7 +138,9 @@ case class PivotFirst( override val aggBufferAttributes: Seq[AttributeReference] = - pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)()) + pivotIndex.toList.sortBy(_._2).map { kv => + AttributeReference(Option(kv._1).getOrElse("null").toString, valueDataType)() + } override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 51ffe3417271..ca3cb5676742 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -216,4 +216,18 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil ) } + + test("pivot with null should not throw NPE") { + checkAnswer( + Seq(Tuple1(None), Tuple1(Some(1))).toDF("a").groupBy($"a").pivot("a").count(), + Row(null, 1, null) :: Row(1, null, 1) :: Nil) + } + + test("pivot with null and aggregate type not supported by PivotFirst returns correct result") { + checkAnswer( + Seq(Tuple1(None), Tuple1(Some(1))).toDF("a") + .withColumn("b", expr("array(a, 7)")) + .groupBy($"a").pivot("a").agg(min($"b")), + Row(null, Seq(null, 7), null) :: Row(1, null, Seq(1, 7)) :: Nil) + } } From 7b5d873aef672aa0aee41e338bab7428101e1ad3 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Fri, 17 Mar 2017 09:33:45 -0500 Subject: [PATCH 057/512] [SPARK-13369] Add config for number of consecutive fetch failures The previously hardcoded max 4 retries per stage is not suitable for all cluster configurations. Since spark retries a stage at the sign of the first fetch failure, you can easily end up with many stage retries to discover all the failures. In particular, two scenarios this value should change are (1) if there are more than 4 executors per node; in that case, it may take 4 retries to discover the problem with each executor on the node and (2) during cluster maintenance on large clusters, where multiple machines are serviced at once, but you also cannot afford total cluster downtime. By making this value configurable, cluster managers can tune this value to something more appropriate to their cluster configuration. Unit tests Author: Sital Kedia Closes #17307 from sitalkedia/SPARK-13369. --- .../apache/spark/scheduler/DAGScheduler.scala | 15 +++++++++++++-- .../org/apache/spark/scheduler/Stage.scala | 18 +----------------- .../spark/scheduler/DAGSchedulerSuite.scala | 16 ++++++++-------- docs/configuration.md | 5 +++++ 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 692ed8083475..d944f268755d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -187,6 +187,13 @@ class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. 
*/ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + /** + * Number of consecutive stage attempts allowed before a stage is aborted. + */ + private[scheduler] val maxConsecutiveStageAttempts = + sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", + DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) + private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") @@ -1282,8 +1289,9 @@ class DAGScheduler( s"longer running") } + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = - failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId) || + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || disallowStageRetryForTest if (shouldAbortStage) { @@ -1292,7 +1300,7 @@ class DAGScheduler( } else { s"""$failedStage (${failedStage.name}) |has failed the maximum allowable number of - |times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. + |times: $maxConsecutiveStageAttempts. |Most recent failure reason: $failureMessage""".stripMargin.replaceAll("\n", " ") } abortStage(failedStage, abortMessage, None) @@ -1726,4 +1734,7 @@ private[spark] object DAGScheduler { // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in val RESUBMIT_TIMEOUT = 200 + + // Number of consecutive stage attempts allowed before a stage is aborted + val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 32e5df6d75f4..290fd073caf2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -87,23 +87,12 @@ private[scheduler] abstract class Stage( * We keep track of each attempt ID that has failed to avoid recording duplicate failures if * multiple tasks from the same stage attempt fail (SPARK-5945). */ - private val fetchFailedAttemptIds = new HashSet[Int] + val fetchFailedAttemptIds = new HashSet[Int] private[scheduler] def clearFailures() : Unit = { fetchFailedAttemptIds.clear() } - /** - * Check whether we should abort the failedStage due to multiple consecutive fetch failures. - * - * This method updates the running set of failed stage attempts and returns - * true if the number of failures exceeds the allowable number of failures. - */ - private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = { - fetchFailedAttemptIds.add(stageAttemptId) - fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES - } - /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ def makeNewStageAttempt( numPartitionsToCompute: Int, @@ -128,8 +117,3 @@ private[scheduler] abstract class Stage( /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). 
*/ def findMissingPartitions(): Seq[Int] } - -private[scheduler] object Stage { - // The number of consecutive failures allowed before a stage is aborted - val MAX_CONSECUTIVE_FETCH_FAILURES = 4 -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8eaf9dfcf49b..dfad5db68a91 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -801,7 +801,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) { // Complete all the tasks for the current attempt of stage 0 successfully completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) @@ -813,7 +813,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // map output, for the next iteration through the loop scheduler.resubmitFailedStages() - if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + if (attempt < scheduler.maxConsecutiveStageAttempts - 1) { assert(scheduler.runningStages.nonEmpty) assert(!ended) } else { @@ -847,11 +847,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations, // stage 2 fails. - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) { // Complete all the tasks for the current attempt of stage 0 successfully completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) - if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) { + if (attempt < scheduler.maxConsecutiveStageAttempts / 2) { // Now we should have a new taskSet, for a new attempt of stage 1. // Fail all these tasks with FetchFailure completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) @@ -859,8 +859,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(1, attempt, numShufflePartitions = 1) // Fail stage 2 - completeNextStageWithFetchFailure(2, attempt - Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2, - shuffleDepTwo) + completeNextStageWithFetchFailure(2, + attempt - scheduler.maxConsecutiveStageAttempts / 2, shuffleDepTwo) } // this will trigger a resubmission of stage 0, since we've lost some of its @@ -872,7 +872,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(1, 4, numShufflePartitions = 1) // Succeed stage2 with a "42" - completeNextResultStageWithSuccess(2, Stage.MAX_CONSECUTIVE_FETCH_FAILURES/2) + completeNextResultStageWithSuccess(2, scheduler.maxConsecutiveStageAttempts / 2) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -895,7 +895,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou submit(finalRdd, Array(0)) // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times. 
- for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts - 1) { // Make each task in stage 0 success completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) diff --git a/docs/configuration.md b/docs/configuration.md index 63392a741a1f..4729f1b0404c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1506,6 +1506,11 @@ Apart from these, the following properties are also available, and may be useful of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering an executor unusable. + spark.stage.maxConsecutiveAttempts + 4 + + Number of consecutive stage attempts allowed before a stage is aborted. + From 376d782164437573880f0ad58cecae1cb5f212f2 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 17 Mar 2017 11:12:23 -0700 Subject: [PATCH 058/512] [SPARK-19986][TESTS] Make pyspark.streaming.tests.CheckpointTests more stable ## What changes were proposed in this pull request? Sometimes, CheckpointTests will hang on a busy machine because the streaming jobs are too slow and cannot catch up. I observed the scheduled delay was keeping increasing for dozens of seconds locally. This PR increases the batch interval from 0.5 seconds to 2 seconds to generate less Spark jobs. It should make `pyspark.streaming.tests.CheckpointTests` more stable. I also replaced `sleep` with `awaitTerminationOrTimeout` so that if the streaming job fails, it will also fail the test. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17323 from zsxwing/SPARK-19986. --- python/pyspark/streaming/tests.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 2e8ed698278d..1bec33509580 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -903,11 +903,11 @@ def updater(vs, s): def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) - ssc = StreamingContext(sc, 0.5) + ssc = StreamingContext(sc, 2) dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") - wc.checkpoint(.5) + wc.checkpoint(2) self.setupCalled = True return ssc @@ -921,21 +921,22 @@ def setup(): def check_output(n): while not os.listdir(outputd): - time.sleep(0.01) + if self.ssc.awaitTerminationOrTimeout(0.5): + raise Exception("ssc stopped") time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) while True: + if self.ssc.awaitTerminationOrTimeout(0.5): + raise Exception("ssc stopped") p = os.path.join(outputd, max(os.listdir(outputd))) if '_SUCCESS' not in os.listdir(p): # not finished - time.sleep(0.01) continue ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: - time.sleep(0.01) continue self.assertEqual(10, len(d)) s = set(d) From bfdeea5c68f963ce60d48d0aa4a4c8c582169950 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Fri, 17 Mar 2017 14:23:07 -0700 Subject: [PATCH 059/512] [SPARK-18847][GRAPHX] PageRank gives incorrect results for graphs with sinks ## What changes were proposed in this pull request? Graphs with sinks (vertices with no outgoing edges) don't have the expected rank sum of n (or 1 for personalized). 
We fix this by normalizing to the expected sum at the end of each implementation. Additionally this fixes the dynamic version of personal pagerank which gave incorrect answers that were not detected by existing unit tests. ## How was this patch tested? Revamped existing and additional unit tests with reference values (and reproduction code) from igraph and NetworkX. Note that for comparison on personal pagerank we use the arpack algorithm in igraph as prpack (the current default) redistributes rank to all vertices uniformly instead of just to the personalization source. We could take the alternate convention (redistribute rank to all vertices uniformly) but that would involve more extensive changes to the algorithms (the dynamic version would no longer be able to use Pregel). Author: Andrew Ray Closes #16483 from aray/pagerank-sink2. --- .../apache/spark/graphx/lib/PageRank.scala | 45 +++-- .../spark/graphx/lib/PageRankSuite.scala | 158 +++++++++++++----- 2 files changed, 144 insertions(+), 59 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 37b6e453592e..13b2b5771918 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -162,7 +162,8 @@ object PageRank extends Logging { iteration += 1 } - rankGraph + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + normalizeRankSum(rankGraph, personalized) } /** @@ -179,7 +180,8 @@ object PageRank extends Logging { * @param resetProb The random reset probability * @param sources The list of sources to compute personalized pagerank from * @return the graph with vertex attributes - * containing the pagerank relative to all starting nodes (as a sparse vector) and + * containing the pagerank relative to all starting nodes (as a sparse vector + * indexed by the position of nodes in the sources list) and * edge attributes the normalized edge weight */ def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], @@ -194,6 +196,8 @@ object PageRank extends Logging { // TODO if one sources vertex id is outside of the int range // we won't be able to store its activations in a sparse vector + require(sources.max <= Int.MaxValue.toLong, + s"This implementation currently only works for source vertex ids at most ${Int.MaxValue}") val zero = Vectors.sparse(sources.size, List()).asBreeze val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) => val v = Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze @@ -245,8 +249,10 @@ object PageRank extends Logging { i += 1 } + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + val rankSums = rankGraph.vertices.values.fold(zero)(_ :+ _) rankGraph.mapVertices { (vid, attr) => - Vectors.fromBreeze(attr) + Vectors.fromBreeze(attr :/ rankSums) } } @@ -307,7 +313,7 @@ object PageRank extends Logging { .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to (initialPR, delta = 0) .mapVertices { (id, attr) => - if (id == src) (1.0, Double.NegativeInfinity) else (0.0, 0.0) + if (id == src) (0.0, Double.NegativeInfinity) else (0.0, 0.0) } .cache() @@ -322,13 +328,12 @@ object PageRank extends Logging { def personalizedVertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = { val (oldPR, lastDelta) = attr - var teleport = oldPR - val delta = if 
(src==id) resetProb else 0.0 - teleport = oldPR*delta - - val newPR = teleport + (1.0 - resetProb) * msgSum - val newDelta = if (lastDelta == Double.NegativeInfinity) newPR else newPR - oldPR - (newPR, newDelta) + val newPR = if (lastDelta == Double.NegativeInfinity) { + 1.0 + } else { + oldPR + (1.0 - resetProb) * msgSum + } + (newPR, newPR - oldPR) } def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { @@ -353,9 +358,23 @@ object PageRank extends Logging { vertexProgram(id, attr, msgSum) } - Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( + val rankGraph = Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( vp, sendMessage, messageCombiner) .mapVertices((vid, attr) => attr._1) - } // end of deltaPageRank + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + normalizeRankSum(rankGraph, personalized) + } + + // Normalizes the sum of ranks to n (or 1 if personalized) + private def normalizeRankSum(rankGraph: Graph[Double, Double], personalized: Boolean) = { + val rankSum = rankGraph.vertices.values.sum() + if (personalized) { + rankGraph.mapVertices((id, rank) => rank / rankSum) + } else { + val numVertices = rankGraph.numVertices + val correctionFactor = numVertices.toDouble / rankSum + rankGraph.mapVertices((id, rank) => rank * correctionFactor) + } + } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 6afbb5a95989..9779553ce85d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -50,7 +50,8 @@ object GridPageRank { inNbrs(ind).map( nbr => oldPr(nbr) / outDegree(nbr)).sum } } - (0L until (nRows * nCols)).zip(pr) + val prSum = pr.sum + (0L until (nRows * nCols)).zip(pr.map(_ * pr.length / prSum)) } } @@ -68,26 +69,34 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { val nVertices = 100 val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() val resetProb = 0.15 + val tol = 0.0001 + val numIter = 2 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPageRank(numIter = 2, resetProb).vertices - val staticRanks2 = starGraph.staticPageRank(numIter = 3, resetProb).vertices.cache() + val staticRanks = starGraph.staticPageRank(numIter, resetProb).vertices.cache() + val staticRanks2 = starGraph.staticPageRank(numIter + 1, resetProb).vertices - // Static PageRank should only take 3 iterations to converge - val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => + // Static PageRank should only take 2 iterations to converge + val notMatching = staticRanks.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 }.map { case (vid, test) => test }.sum() assert(notMatching === 0) - val staticErrors = staticRanks2.map { case (vid, pr) => - val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) - val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5) - if (!correct) 1 else 0 - } - assert(staticErrors.sum === 0) + val dynamicRanks = starGraph.pageRank(tol, resetProb).vertices.cache() + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in")) + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)])) + // We multiply by the 
number of vertices to account for difference in normalization + val centerRank = 0.462394787 * nVertices + val othersRank = 0.005430356 * nVertices + val igraphPR = centerRank +: Seq.fill(nVertices - 1)(othersRank) + val ranks = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR)) + assert(compareRanks(staticRanks, ranks) < errorTol) + assert(compareRanks(dynamicRanks, ranks) < errorTol) - val dynamicRanks = starGraph.pageRank(0, resetProb).vertices.cache() - assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) } } // end of test Star PageRank @@ -96,51 +105,62 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { val nVertices = 100 val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() val resetProb = 0.15 + val tol = 0.0001 + val numIter = 2 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) - .vertices.cache() - - // Static PageRank should only take 2 iterations to converge - val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => - if (pr1 != pr2) 1 else 0 - }.map { case (vid, test) => test }.sum - assert(notMatching === 0) + val staticRanks = starGraph.staticPersonalizedPageRank(0, numIter, resetProb).vertices.cache() - val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == 0.0) || - (vid == 0 && pr == resetProb) - if (!correct) 1 else 0 - } - assert(staticErrors.sum === 0) - - val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() - assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) + val dynamicRanks = starGraph.personalizedPageRank(0, tol, resetProb).vertices.cache() + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) - val parallelStaticRanks1 = starGraph - .staticParallelPersonalizedPageRank(Array(0), 1, resetProb).mapVertices { + val parallelStaticRanks = starGraph + .staticParallelPersonalizedPageRank(Array(0), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(0) }.vertices.cache() - assert(compareRanks(staticRanks1, parallelStaticRanks1) < errorTol) + assert(compareRanks(staticRanks, parallelStaticRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in"), personalized = c(1, rep(0, 99)), algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. 
+ // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)]), + // personalization=dict([(x, 1 if x == 0 else 0) for x in range(0,100)])) + // We multiply by the number of vertices to account for difference in normalization + val igraphPR0 = 1.0 +: Seq.fill(nVertices - 1)(0.0) + val ranks0 = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR0)) + assert(compareRanks(staticRanks, ranks0) < errorTol) + assert(compareRanks(dynamicRanks, ranks0) < errorTol) - val parallelStaticRanks2 = starGraph - .staticParallelPersonalizedPageRank(Array(0, 1), 2, resetProb).mapVertices { - case (vertexId, vector) => vector(0) - }.vertices.cache() - assert(compareRanks(staticRanks2, parallelStaticRanks2) < errorTol) // We have one outbound edge from 1 to 0 - val otherStaticRanks2 = starGraph.staticPersonalizedPageRank(1, numIter = 2, resetProb) + val otherStaticRanks = starGraph.staticPersonalizedPageRank(1, numIter, resetProb) .vertices.cache() - val otherDynamicRanks = starGraph.personalizedPageRank(1, 0, resetProb).vertices.cache() - val otherParallelStaticRanks2 = starGraph - .staticParallelPersonalizedPageRank(Array(0, 1), 2, resetProb).mapVertices { + val otherDynamicRanks = starGraph.personalizedPageRank(1, tol, resetProb).vertices.cache() + val otherParallelStaticRanks = starGraph + .staticParallelPersonalizedPageRank(Array(0, 1), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(1) }.vertices.cache() - assert(compareRanks(otherDynamicRanks, otherStaticRanks2) < errorTol) - assert(compareRanks(otherStaticRanks2, otherParallelStaticRanks2) < errorTol) - assert(compareRanks(otherDynamicRanks, otherParallelStaticRanks2) < errorTol) + assert(compareRanks(otherDynamicRanks, otherStaticRanks) < errorTol) + assert(compareRanks(otherStaticRanks, otherParallelStaticRanks) < errorTol) + assert(compareRanks(otherDynamicRanks, otherParallelStaticRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in"), + // personalized = c(0, 1, rep(0, 98)), algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. 
+ // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)]), + // personalization=dict([(x, 1 if x == 1 else 0) for x in range(0,100)])) + val centerRank = 0.4594595 + val sourceRank = 0.5405405 + val igraphPR1 = centerRank +: sourceRank +: Seq.fill(nVertices - 2)(0.0) + val ranks1 = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR1)) + assert(compareRanks(otherStaticRanks, ranks1) < errorTol) + assert(compareRanks(otherDynamicRanks, ranks1) < errorTol) + assert(compareRanks(otherParallelStaticRanks, ranks1) < errorTol) } } // end of test Star PersonalPageRank @@ -229,4 +249,50 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { } } + + test("Loop with sink PageRank") { + withSpark { sc => + val edges = sc.parallelize((1L, 2L) :: (2L, 3L) :: (3L, 1L) :: (1L, 4L) :: Nil) + val g = Graph.fromEdgeTuples(edges, 1) + val resetProb = 0.15 + val tol = 0.0001 + val numIter = 20 + val errorTol = 1.0e-5 + + val staticRanks = g.staticPageRank(numIter, resetProb).vertices.cache() + val dynamicRanks = g.pageRank(tol, resetProb).vertices.cache() + + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(graph_from_literal( A -+ B -+ C -+ A -+ D)) + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(1,2),(2,3),(3,1),(1,4)])) + // We multiply by the number of vertices to account for difference in normalization + val igraphPR = Seq(0.3078534, 0.2137622, 0.2646223, 0.2137622).map(_ * 4) + val ranks = VertexRDD(sc.parallelize(1L to 4L zip igraphPR)) + assert(compareRanks(staticRanks, ranks) < errorTol) + assert(compareRanks(dynamicRanks, ranks) < errorTol) + + val p1staticRanks = g.staticPersonalizedPageRank(1, numIter, resetProb).vertices.cache() + val p1dynamicRanks = g.personalizedPageRank(1, tol, resetProb).vertices.cache() + val p1parallelDynamicRanks = + g.staticParallelPersonalizedPageRank(Array(1, 2, 3, 4), numIter, resetProb) + .vertices.mapValues(v => v(0)).cache() + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(graph_from_literal( A -+ B -+ C -+ A -+ D), personalized = c(1, 0, 0, 0), + // algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(1,2),(2,3),(3,1),(1,4)]), personalization={1:1, 2:0, 3:0, 4:0}) + val igraphPR2 = Seq(0.4522329, 0.1921990, 0.1633691, 0.1921990) + val ranks2 = VertexRDD(sc.parallelize(1L to 4L zip igraphPR2)) + assert(compareRanks(p1staticRanks, ranks2) < errorTol) + assert(compareRanks(p1dynamicRanks, ranks2) < errorTol) + assert(compareRanks(p1parallelDynamicRanks, ranks2) < errorTol) + + } + } } From 7de66bae58733595cb88ec899640f7acf734d5c4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 17 Mar 2017 14:51:59 -0700 Subject: [PATCH 060/512] [SPARK-19967][SQL] Add from_json in FunctionRegistry ## What changes were proposed in this pull request? This pr added entries in `FunctionRegistry` and supported `from_json` in SQL. ## How was this patch tested? Added tests in `JsonFunctionsSuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro Closes #17320 from maropu/SPARK-19967. 
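As a quick, hedged illustration of the new SQL entry point (not part of the patch itself): the sketch below just wraps the two examples documented in the `ExpressionDescription` and `JsonFunctionsSuite` changes in a runnable `local[*]` session. The object name and session boilerplate are mine; the queries follow the patch.

```scala
import org.apache.spark.sql.SparkSession

object FromJsonSqlSketch {
  def main(args: Array[String]): Unit = {
    // Boilerplate session; any Spark build that contains this patch will do.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("from_json in SQL")
      .getOrCreate()
    import spark.implicits._

    // The schema argument is a DDL-formatted string literal, parsed via CatalystSqlParser.
    spark.sql("""SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE')""").show(false)

    // Options are passed with map(), e.g. a custom timestamp format.
    Seq("""{"time": "26/08/2015 18:00"}""").toDS()
      .selectExpr("from_json(value, 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy HH:mm'))")
      .show(false)

    spark.stop()
  }
}
```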
--- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/jsonExpressions.scala | 36 +++++- .../sql-tests/inputs/json-functions.sql | 13 +++ .../sql-tests/results/json-functions.sql.out | 107 +++++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 36 ++++++ 5 files changed, 189 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0dcb44081f60..0486e67dbdf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -426,6 +426,7 @@ object FunctionRegistry { // json expression[StructToJson]("to_json"), + expression[JsonToStruct]("from_json"), // Cast aliases (SPARK-16730) castAlias("boolean", BooleanType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 18b5f2f7ed2e..37e4bb506043 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -26,6 +26,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes} @@ -483,6 +484,17 @@ case class JsonTuple(children: Seq[Expression]) /** * Converts an json input string to a [[StructType]] or [[ArrayType]] with the specified schema. 
*/ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`.", + extended = """ + Examples: + > SELECT _FUNC_('{"a":1, "b":0.8}', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + > SELECT _FUNC_('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); + {"time":"2015-08-26 00:00:00.0"} + """) +// scalastyle:on line.size.limit case class JsonToStruct( schema: DataType, options: Map[String, String], @@ -494,6 +506,21 @@ case class JsonToStruct( def this(schema: DataType, options: Map[String, String], child: Expression) = this(schema, options, child, None) + // Used in `FunctionRegistry` + def this(child: Expression, schema: Expression) = + this( + schema = JsonExprUtils.validateSchemaLiteral(schema), + options = Map.empty[String, String], + child = child, + timeZoneId = None) + + def this(child: Expression, schema: Expression, options: Expression) = + this( + schema = JsonExprUtils.validateSchemaLiteral(schema), + options = JsonExprUtils.convertToMapData(options), + child = child, + timeZoneId = None) + override def checkInputDataTypes(): TypeCheckResult = schema match { case _: StructType | ArrayType(_: StructType, _) => super.checkInputDataTypes() @@ -589,7 +616,7 @@ case class StructToJson( def this(child: Expression) = this(Map.empty, child, None) def this(child: Expression, options: Expression) = this( - options = StructToJson.convertToMapData(options), + options = JsonExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -634,7 +661,12 @@ case class StructToJson( override def inputTypes: Seq[AbstractDataType] = StructType :: Nil } -object StructToJson { +object JsonExprUtils { + + def validateSchemaLiteral(exp: Expression): StructType = exp match { + case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) + case e => throw new AnalysisException(s"Expected a string literal instead of $e") + } def convertToMapData(exp: Expression): Map[String, String] = exp match { case m: CreateMap diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 9308560451bf..83243c5e5a12 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -5,4 +5,17 @@ select to_json(named_struct('a', 1, 'b', 2)); select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); -- Check if errors handled select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); +select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)); select to_json(); + +-- from_json +describe function from_json; +describe function extended from_json; +select from_json('{"a":1}', 'a INT'); +select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors handled +select from_json('{"a":1}', 1); +select from_json('{"a":1}', 'a InvalidType'); +select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')); +select from_json('{"a":1}', 'a INT', map('mode', 1)); +select from_json(); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index d8aa4fb9fa78..b57cbbc1d843 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 16 -- !query 0 @@ -55,9 +55,112 @@ Must use a map() function for options;; line 1 pos 7 -- !query 5 -select to_json() +select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) -- !query 5 schema struct<> -- !query 5 output org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 + + +-- !query 6 +select to_json() +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException Invalid number of arguments for function to_json; line 1 pos 7 + + +-- !query 7 +describe function from_json +-- !query 7 schema +struct +-- !query 7 output +Class: org.apache.spark.sql.catalyst.expressions.JsonToStruct +Function: from_json +Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. + + +-- !query 8 +describe function extended from_json +-- !query 8 schema +struct +-- !query 8 output +Class: org.apache.spark.sql.catalyst.expressions.JsonToStruct +Extended Usage: + Examples: + > SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + > SELECT from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); + {"time":"2015-08-26 00:00:00.0"} + +Function: from_json +Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. + + +-- !query 9 +select from_json('{"a":1}', 'a INT') +-- !query 9 schema +struct> +-- !query 9 output +{"a":1} + + +-- !query 10 +select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) +-- !query 10 schema +struct> +-- !query 10 output +{"time":2015-08-26 00:00:00.0} + + +-- !query 11 +select from_json('{"a":1}', 1) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Expected a string literal instead of 1;; line 1 pos 7 + + +-- !query 12 +select from_json('{"a":1}', 'a InvalidType') +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException + +DataType invalidtype() is not supported.(line 1, pos 2) + +== SQL == +a InvalidType +--^^^ +; line 1 pos 7 + + +-- !query 13 +select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')) +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 14 +select from_json('{"a":1}', 'a INT', map('mode', 1)) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 + + +-- !query 15 +select from_json() +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function from_json; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index cdea3b9a0f79..2345b8208116 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -220,4 +220,40 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg2.getMessage.startsWith( "A type of 
keys and values in map() must be string, but got")) } + + test("SPARK-19967 Support from_json in SQL") { + val df1 = Seq("""{"a": 1}""").toDS() + checkAnswer( + df1.selectExpr("from_json(value, 'a INT')"), + Row(Row(1)) :: Nil) + + val df2 = Seq("""{"c0": "a", "c1": 1, "c2": {"c20": 3.8, "c21": 8}}""").toDS() + checkAnswer( + df2.selectExpr("from_json(value, 'c0 STRING, c1 INT, c2 STRUCT')"), + Row(Row("a", 1, Row(3.8, 8))) :: Nil) + + val df3 = Seq("""{"time": "26/08/2015 18:00"}""").toDS() + checkAnswer( + df3.selectExpr( + "from_json(value, 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy HH:mm'))"), + Row(Row(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))) + + val errMsg1 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 1)") + } + assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + val errMsg2 = intercept[AnalysisException] { + df3.selectExpr("""from_json(value, 'time InvalidType')""") + } + assert(errMsg2.getMessage.contains("DataType invalidtype() is not supported")) + val errMsg3 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))") + } + assert(errMsg3.getMessage.startsWith("Must use a map() function for options")) + val errMsg4 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 'time Timestamp', map('a', 1))") + } + assert(errMsg4.getMessage.startsWith( + "A type of keys and values in map() must be string, but got")) + } } From 3783539d7ab83a2a632a9f35ca66ae39d01c28b6 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Fri, 17 Mar 2017 16:16:22 -0700 Subject: [PATCH 061/512] [SPARK-19873][SS] Record num shuffle partitions in offset log and enforce in next batch. ## What changes were proposed in this pull request? If the user changes the shuffle partition number between batches, Streaming aggregation will fail. Here are some possible cases: - Change "spark.sql.shuffle.partitions" - Use "repartition" and change the partition number in codes - RangePartitioner doesn't generate deterministic partitions. Right now it's safe as we disallow sort before aggregation. Not sure if we will add some operators using RangePartitioner in future. ## How was this patch tested? - Unit tests - Manual tests - forward compatibility tested by using the new `OffsetSeqMetadata` json with Spark v2.1.0 Author: Kunal Khamar Closes #17216 from kunalkhamar/num-partitions. 
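A compact, illustrative sketch of the enforcement rule added here (this is not the actual `StreamExecution` code; the helper name and warning text are made up for illustration — the real logic lives in `StreamExecution.populateStartOffsets` and operates on `OffsetSeqMetadata.conf`): on restart, the shuffle-partition count recorded in the offset log overrides the restarted session's `spark.sql.shuffle.partitions`, while a checkpoint written before this patch (e.g. by Spark 2.1.0) falls back to the session value with a warning. The metadata line in each `offsets/<batchId>` file now carries this small conf map alongside the watermark and batch timestamp.

```scala
// Illustrative sketch only; helper and warning are hypothetical, the rule mirrors the diff below.
object ShufflePartitionRecoverySketch {
  val key = "spark.sql.shuffle.partitions"

  def shufflePartitionsToUse(
      checkpointedConf: Map[String, String], // OffsetSeqMetadata.conf read back from offsets/<batchId>
      sessionValue: Int): Int =
    checkpointedConf.get(key) match {
      case Some(recorded) => recorded.toInt  // enforce what the query originally ran with
      case None =>
        // Backward compatibility with checkpoints that predate this patch.
        println(s"WARN: checkpoint has no recorded partition count; using session value $sessionValue")
        sessionValue
    }

  def main(args: Array[String]): Unit = {
    // Checkpoint written with 10 partitions, query restarted with 5: the recorded 10 wins.
    assert(shufflePartitionsToUse(Map(key -> "10"), sessionValue = 5) == 10)
    // Old checkpoint without a conf map: keep the session value.
    assert(shufflePartitionsToUse(Map.empty, sessionValue = 5) == 5)
  }
}
```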
--- .../sql/execution/streaming/OffsetSeq.scala | 8 +- .../execution/streaming/StreamExecution.scala | 60 ++++++++--- .../sql/streaming/StreamingQueryManager.scala | 8 +- .../checkpoint-version-2.1.0/metadata | 1 + .../checkpoint-version-2.1.0/offsets/0 | 3 + .../checkpoint-version-2.1.0/offsets/1 | 3 + .../state/0/0/1.delta | Bin 0 -> 46 bytes .../state/0/0/2.delta | Bin 0 -> 46 bytes .../state/0/1/1.delta | Bin 0 -> 79 bytes .../state/0/1/2.delta | Bin 0 -> 79 bytes .../state/0/2/1.delta | Bin 0 -> 79 bytes .../state/0/2/2.delta | Bin 0 -> 79 bytes .../state/0/3/1.delta | Bin 0 -> 73 bytes .../state/0/3/2.delta | Bin 0 -> 79 bytes .../state/0/4/1.delta | Bin 0 -> 79 bytes .../state/0/4/2.delta | Bin 0 -> 46 bytes .../state/0/5/1.delta | Bin 0 -> 46 bytes .../state/0/5/2.delta | Bin 0 -> 46 bytes .../state/0/6/1.delta | Bin 0 -> 46 bytes .../state/0/6/2.delta | Bin 0 -> 79 bytes .../state/0/7/1.delta | Bin 0 -> 46 bytes .../state/0/7/2.delta | Bin 0 -> 79 bytes .../state/0/8/1.delta | Bin 0 -> 46 bytes .../state/0/8/2.delta | Bin 0 -> 46 bytes .../state/0/9/1.delta | Bin 0 -> 46 bytes .../state/0/9/2.delta | Bin 0 -> 79 bytes .../streaming/OffsetSeqLogSuite.scala | 38 +++++-- .../spark/sql/streaming/StreamSuite.scala | 101 +++++++++++++++++- .../StreamingQueryManagerSuite.scala | 10 -- .../test/DataStreamReaderWriterSuite.scala | 22 ++-- 30 files changed, 207 insertions(+), 47 deletions(-) create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/2.delta create mode 100644 
sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index e5a1997d6b80..8249adab4bba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming import org.json4s.NoTypeHints import org.json4s.jackson.Serialization - /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance @@ -70,8 +69,12 @@ object OffsetSeq { * bound the lateness of data that will processed. Time unit: milliseconds * @param batchTimestampMs: The current batch processing timestamp. * Time unit: milliseconds + * @param conf: Additional conf_s to be persisted across batches, e.g. number of shuffle partitions. */ -case class OffsetSeqMetadata(var batchWatermarkMs: Long = 0, var batchTimestampMs: Long = 0) { +case class OffsetSeqMetadata( + batchWatermarkMs: Long = 0, + batchTimestampMs: Long = 0, + conf: Map[String, String] = Map.empty) { def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } @@ -79,4 +82,3 @@ object OffsetSeqMetadata { private implicit val format = Serialization.formats(NoTypeHints) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 529263805c0a..40faddccc242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} @@ -117,7 +118,9 @@ class StreamExecution( } /** Metadata associated with the offset seq of a batch in the query. */ - protected var offsetSeqMetadata = OffsetSeqMetadata() + protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, + conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> + sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString)) override val id: UUID = UUID.fromString(streamMetadata.id) @@ -256,6 +259,15 @@ class StreamExecution( updateStatusMessage("Initializing sources") // force initialization of the logical plan so that the sources can be created logicalPlan + + // Isolated spark session to run the batches with. 
+ val sparkSessionToRunBatches = sparkSession.cloneSession() + // Adaptive execution can change num shuffle partitions, disallow + sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") + offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, + conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> + sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key))) + if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` initializationLatch.countDown() @@ -268,7 +280,7 @@ class StreamExecution( reportTimeTaken("triggerExecution") { if (currentBatchId < 0) { // We'll do this initialization only once - populateStartOffsets() + populateStartOffsets(sparkSessionToRunBatches) logDebug(s"Stream running from $committedOffsets to $availableOffsets") } else { constructNextBatch() @@ -276,7 +288,7 @@ class StreamExecution( if (dataAvailable) { currentStatus = currentStatus.copy(isDataAvailable = true) updateStatusMessage("Processing new data") - runBatch() + runBatch(sparkSessionToRunBatches) } } @@ -381,13 +393,32 @@ class StreamExecution( * - committedOffsets * - availableOffsets */ - private def populateStartOffsets(): Unit = { + private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { offsetLog.getLatest() match { case Some((batchId, nextOffsets)) => logInfo(s"Resuming streaming query, starting with batch $batchId") currentBatchId = batchId availableOffsets = nextOffsets.toStreamProgress(sources) - offsetSeqMetadata = nextOffsets.metadata.getOrElse(OffsetSeqMetadata()) + + // update offset metadata + nextOffsets.metadata.foreach { metadata => + val shufflePartitionsSparkSession: Int = + sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS) + val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, { + // For backward compatibility, if # partitions was not recorded in the offset log, + // then ensure it is not missing. The new value is picked up from the conf. + logWarning("Number of shuffle partitions from previous run not found in checkpoint. " + + s"Using the value from the conf, $shufflePartitionsSparkSession partitions.") + shufflePartitionsSparkSession + }) + offsetSeqMetadata = OffsetSeqMetadata( + metadata.batchWatermarkMs, metadata.batchTimestampMs, + metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString)) + // Update conf with correct number of shuffle partitions + sparkSessionToRunBatches.conf.set( + SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString) + } + logDebug(s"Found possibly unprocessed offsets $availableOffsets " + s"at batch timestamp ${offsetSeqMetadata.batchTimestampMs}") @@ -444,8 +475,7 @@ class StreamExecution( } } if (hasNewData) { - // Current batch timestamp in milliseconds - offsetSeqMetadata.batchTimestampMs = triggerClock.getTimeMillis() + var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs // Update the eventTime watermark if we find one in the plan. 
if (lastExecution != null) { lastExecution.executedPlan.collect { @@ -453,16 +483,19 @@ class StreamExecution( logDebug(s"Observed event time stats: ${e.eventTimeStats.value}") e.eventTimeStats.value.max - e.delayMs }.headOption.foreach { newWatermarkMs => - if (newWatermarkMs > offsetSeqMetadata.batchWatermarkMs) { + if (newWatermarkMs > batchWatermarkMs) { logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - offsetSeqMetadata.batchWatermarkMs = newWatermarkMs + batchWatermarkMs = newWatermarkMs } else { logDebug( s"Event time didn't move: $newWatermarkMs < " + - s"${offsetSeqMetadata.batchWatermarkMs}") + s"$batchWatermarkMs") } } } + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = batchWatermarkMs, + batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { @@ -505,8 +538,9 @@ class StreamExecution( /** * Processes any data available between `availableOffsets` and `committedOffsets`. + * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ - private def runBatch(): Unit = { + private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -551,7 +585,7 @@ class StreamExecution( reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( - sparkSession, + sparkSessionToRunBatch, triggerLogicalPlan, outputMode, checkpointFile("state"), @@ -561,7 +595,7 @@ class StreamExecution( } val nextBatch = - new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema)) + new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) reportTimeTaken("addBatch") { sink.addBatch(currentBatchId, nextBatch) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 38edb40dfb78..7810d9f6e964 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ @@ -40,7 +41,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} */ @Experimental @InterfaceStability.Evolving -class StreamingQueryManager private[sql] (sparkSession: SparkSession) { +class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) @@ -234,9 +235,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { } if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { - throw new AnalysisException( - s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " + - "is not supported in streaming DataFrames/Datasets") + logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " + + "is not supported in streaming DataFrames/Datasets and will be disabled.") } new StreamingQueryWrapper(new StreamExecution( diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata new file mode 100644 index 000000000000..3492220e36b8 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata @@ -0,0 +1 @@ +{"id":"dddc5e7f-1e71-454c-8362-de184444fb5a"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 new file mode 100644 index 000000000000..cbde042e79af --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1489180207737} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 new file mode 100644 index 000000000000..10b5774746de --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1489180209261} +2 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..7dc49cb3e47fd7a4001ff7ddc96094e754117c44 GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0AjiN4z6GzEx^FYAk56c;0R>PuraWUFbFd8F)RS`fZ#t6 M_&{}vLWCeB023(<4FCWD literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..8b566e81f48663efa0ebda2dbf694d65e28def72 GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0OEK1WO;*uv;YGmgD^7(gCmeF!^Xfa!XU`R$FKm%1A_lR M-~-hu3K4>k06a_$wEzGB literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..ca2a7ed033f3baf749f5a93522e951c15729e4c6 GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0AjUmuFSzeT7ZF(L70Vu!4b%oVPjwyVGv~GV^{#>0l|MD M@PX0l|MD M@PXPuraWUFbFd8F)RS`fZ#t6 M_&{}vLWCeB00o>3)Bpeg literal 0 HcmV?d00001 diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..fe521b8c07504adc9e174c9116497607365cca7e GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0OE5h9UQ?xT7ZF(L70hy!4b%oVPjwyVGv~GV^{#>0l|MD M@PX0l|MD M@PXgCmeF!^Xfa!XU`V$FKm%1A_lR M-~-hu3K4>k06R$yw*UYD literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..0c9b6ac5c863d06d63c46c8a1fc51da716a5fdbd GIT binary patch literal 79 zcmeZ?GI7euPtI0VWnf^i0OIOpUzvh|v;YGmgD@KhgCmeF!^Xfa!XU`R$FKm%1A_lR M-~-hu3K4>k083;I`Tzg` literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index f7f0dade8717..dc556322bedd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -21,6 +21,7 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.stringToFile +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { @@ -29,12 +30,37 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { case class StringOffset(override val json: String) extends Offset test("OffsetSeqMetadata - deserialization") { - assert(OffsetSeqMetadata(0, 0) === OffsetSeqMetadata("""{}""")) - assert(OffsetSeqMetadata(1, 0) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) - assert(OffsetSeqMetadata(0, 2) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) - assert( - OffsetSeqMetadata(1, 2) === - OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) + val key = SQLConf.SHUFFLE_PARTITIONS.key + + def getConfWith(shufflePartitions: Int): Map[String, String] = { + Map(key -> 
shufflePartitions.toString) + } + + // None set + assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) + + // One set + assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) + assert(OffsetSeqMetadata(0, 2, Map.empty) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) + assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) === + OffsetSeqMetadata(s"""{"conf": {"$key":2}}""")) + + // Two set + assert(OffsetSeqMetadata(1, 2, Map.empty) === + OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) + assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}""")) + assert(OffsetSeqMetadata(0, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchTimestampMs":2,"conf": {"$key":3}}""")) + + // All set + assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}}""")) + + // Drop unknown fields + assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata( + s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}},"unknown":1""")) } test("OffsetSeqLog - serialization - deserialization") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 6dfcd8baba20..e867fc40f7f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,17 +17,20 @@ package org.apache.spark.sql.streaming -import java.io.{InterruptedIOException, IOException} +import java.io.{File, InterruptedIOException, IOException} import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import scala.reflect.ClassTag import scala.util.control.ControlThrowable +import org.apache.commons.io.FileUtils + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -389,6 +392,102 @@ class StreamSuite extends StreamTest { query.stop() assert(query.exception.isEmpty) } + + test("SPARK-19873: streaming aggregation with change in number of partitions") { + val inputData = MemoryStream[(Int, Int)] + val agg = inputData.toDS().groupBy("_1").count() + + testStream(agg, OutputMode.Complete())( + AddData(inputData, (1, 0), (2, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "2")), + CheckAnswer((1, 1), (2, 1)), + StopStream, + AddData(inputData, (3, 0), (2, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "5")), + CheckAnswer((1, 1), (2, 2), (3, 1)), + StopStream, + AddData(inputData, (3, 0), (1, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1")), + CheckAnswer((1, 2), (2, 2), (3, 2))) + } + + test("recover from a Spark v2.1 checkpoint") { + var inputData: MemoryStream[Int] = null + var query: DataStreamWriter[Row] = null + + def prepareMemoryStream(): Unit = { + inputData = MemoryStream[Int] + inputData.addData(1, 2, 3, 4) + inputData.addData(3, 4, 5, 
6) + inputData.addData(5, 6, 7, 8) + + query = inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .writeStream + .outputMode("complete") + .format("memory") + } + + // Get an existing checkpoint generated by Spark v2.1. + // v2.1 does not record # shuffle partitions in the offset metadata. + val resourceUri = + this.getClass.getResource("/structured-streaming/checkpoint-version-2.1.0").toURI + val checkpointDir = new File(resourceUri) + + // 1 - Test if recovery from the checkpoint is successful. + prepareMemoryStream() + withTempDir { dir => + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(checkpointDir, dir) + + // Checkpoint data was generated by a query with 10 shuffle partitions. + // In order to test reading from the checkpoint, the checkpoint must have two or more batches, + // since the last batch may be rerun. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + var streamingQuery: StreamingQuery = null + try { + streamingQuery = + query.queryName("counts").option("checkpointLocation", dir.getCanonicalPath).start() + streamingQuery.processAllAvailable() + inputData.addData(9) + streamingQuery.processAllAvailable() + + QueryTest.checkAnswer(spark.table("counts").toDF(), + Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) :: + Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil) + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() + } + } + } + } + + // 2 - Check recovery with wrong num shuffle partitions + prepareMemoryStream() + withTempDir { dir => + FileUtils.copyDirectory(checkpointDir, dir) + + // Since the number of partitions is greater than 10, should throw exception. 
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") { + var streamingQuery: StreamingQuery = null + try { + intercept[StreamingQueryException] { + streamingQuery = + query.queryName("badQuery").option("checkpointLocation", dir.getCanonicalPath).start() + streamingQuery.processAllAvailable() + } + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() + } + } + } + } + } } abstract class FakeSource extends StreamSourceProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index f05e9d1fda73..b49efa689023 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -239,16 +239,6 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { } } - test("SPARK-19268: Adaptive query execution should be disallowed") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val e = intercept[AnalysisException] { - MemoryStream[Int].toDS.writeStream.queryName("test-query").format("memory").start() - } - assert(e.getMessage.contains(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key) && - e.getMessage.contains("not supported")) - } - } - /** Run a body of code by defining a query on each dataset */ private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[StreamingQuery] => Unit): Unit = { failAfter(streamingTimeout) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index f61dcdcbcf71..341ab0eb923d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.hadoop.fs.Path +import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter @@ -370,21 +371,22 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .option("checkpointLocation", checkpointLocationURI.toString) .trigger(ProcessingTime(10.seconds)) .start() + q.processAllAvailable() q.stop() verify(LastOptions.mockStreamSourceProvider).createSource( - spark.sqlContext, - s"$checkpointLocationURI/sources/0", - None, - "org.apache.spark.sql.streaming.test", - Map.empty) + any(), + meq(s"$checkpointLocationURI/sources/0"), + meq(None), + meq("org.apache.spark.sql.streaming.test"), + meq(Map.empty)) verify(LastOptions.mockStreamSourceProvider).createSource( - spark.sqlContext, - s"$checkpointLocationURI/sources/1", - None, - "org.apache.spark.sql.streaming.test", - Map.empty) + any(), + meq(s"$checkpointLocationURI/sources/1"), + meq(None), + meq("org.apache.spark.sql.streaming.test"), + meq(Map.empty)) } private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath From 6326d406b98a34e9cc8afa6743b23ee1cced8611 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 17 Mar 2017 21:55:10 -0700 Subject: [PATCH 062/512] [SQL][MINOR] Fix scaladoc for UDFRegistration ## What changes were proposed in this pull request? Fix scaladoc for UDFRegistration ## How was this patch tested? 
local build Author: Jacek Laskowski Closes #17337 from jaceklaskowski/udfregistration-scaladoc. --- .../main/scala/org/apache/spark/sql/UDFRegistration.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 7abfa4ea37a7..a57673334c10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -36,7 +36,11 @@ import org.apache.spark.sql.types.{DataType, DataTypes} import org.apache.spark.util.Utils /** - * Functions for registering user-defined functions. Use `SQLContext.udf` to access this. + * Functions for registering user-defined functions. Use `SparkSession.udf` to access this: + * + * {{{ + * spark.udf + * }}} * * @note The user-defined functions must be deterministic. * From c083b6b7dec337d680b54dabeaa40e7a0f69ae69 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Sat, 18 Mar 2017 14:07:25 +0800 Subject: [PATCH 063/512] [SPARK-19915][SQL] Exclude cartesian product candidates to reduce the search space ## What changes were proposed in this pull request? We have some concerns about removing size in the cost model [in the previous pr](https://github.com/apache/spark/pull/17240). It's a tradeoff between code structure and algorithm completeness. I tend to keep the size and thus create this new pr without changing cost model. What this pr does: 1. We only consider consecutive inner joinable items, thus excluding cartesian products in reordering procedure. This significantly reduces the search space and memory overhead of memo. Otherwise every combination of items will exist in the memo. 2. This pr also includes a bug fix: if a leaf item is a project(_, child), current solution will miss the project. ## How was this patch tested? Added test cases. Author: wangzhenhua Closes #17286 from wzhfy/joinReorder3. --- .../optimizer/CostBasedJoinReorder.scala | 191 +++++++++--------- .../apache/spark/sql/internal/SQLConf.scala | 11 + .../catalyst/optimizer/JoinReorderSuite.scala | 41 +++- 3 files changed, 143 insertions(+), 100 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index b694561e5372..1b32bda72bc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike} import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf /** @@ -31,19 +31,21 @@ import org.apache.spark.sql.catalyst.rules.Rule * We may have several join reorder algorithms in the future. This class is the entry of these * algorithms, and chooses which one to use. 
*/ -case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { +case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.cboEnabled || !conf.joinReorderEnabled) { plan } else { - val result = plan transform { - case p @ Project(projectList, j @ Join(_, _, _: InnerLike, _)) => - reorder(p, p.outputSet) - case j @ Join(_, _, _: InnerLike, _) => + val result = plan transformDown { + // Start reordering with a joinable item, which is an InnerLike join with conditions. + case j @ Join(_, _, _: InnerLike, Some(cond)) => reorder(j, j.outputSet) + case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond))) + if projectList.forall(_.isInstanceOf[Attribute]) => + reorder(p, p.outputSet) } // After reordering is finished, convert OrderedJoin back to Join - result transform { + result transformDown { case oj: OrderedJoin => oj.join } } @@ -56,7 +58,7 @@ case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] wi // We also need to check if costs of all items can be evaluated. if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty && items.forall(_.stats(conf).rowCount.isDefined)) { - JoinReorderDP.search(conf, items, conditions, output).getOrElse(plan) + JoinReorderDP.search(conf, items, conditions, output) } else { plan } @@ -70,25 +72,26 @@ case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] wi */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, cond) => + case Join(left, right, _: InnerLike, Some(cond)) => val (leftPlans, leftConditions) = extractInnerJoins(left) val (rightPlans, rightConditions) = extractInnerJoins(right) - (leftPlans ++ rightPlans, cond.toSet.flatMap(splitConjunctivePredicates) ++ + (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, join) if projectList.forall(_.isInstanceOf[Attribute]) => - extractInnerJoins(join) + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) + if projectList.forall(_.isInstanceOf[Attribute]) => + extractInnerJoins(j) case _ => (Seq(plan), Set()) } } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, _: InnerLike, cond) => + case j @ Join(left, right, _: InnerLike, Some(cond)) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) OrderedJoin(j.copy(left = replacedLeft, right = replacedRight)) - case p @ Project(_, join) => - p.copy(child = replaceWithOrderedJoin(join)) + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) => + p.copy(child = replaceWithOrderedJoin(j)) case _ => plan } @@ -128,10 +131,10 @@ case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] wi object JoinReorderDP extends PredicateHelper { def search( - conf: CatalystConf, + conf: SQLConf, items: Seq[LogicalPlan], conditions: Set[Expression], - topOutput: AttributeSet): Option[LogicalPlan] = { + topOutput: AttributeSet): LogicalPlan = { // Level i maintains all found plans for i + 1 items. // Create the initial plans: each plan is a single item with zero cost. 
@@ -140,26 +143,22 @@ object JoinReorderDP extends PredicateHelper { case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) }.toMap) - for (lev <- 1 until items.length) { + // Build plans for next levels until the last level has only one plan. This plan contains + // all items that can be joined, so there's no need to continue. + while (foundPlans.size < items.length && foundPlans.last.size > 1) { // Build plans for the next level. foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) } - val plansLastLevel = foundPlans(items.length - 1) - if (plansLastLevel.isEmpty) { - // Failed to find a plan, fall back to the original plan - None - } else { - // There must be only one plan at the last level, which contains all items. - assert(plansLastLevel.size == 1 && plansLastLevel.head._1.size == items.length) - Some(plansLastLevel.head._2.plan) - } + // The last level must have one and only one plan, because all items are joinable. + assert(foundPlans.size == items.length && foundPlans.last.size == 1) + foundPlans.last.head._2.plan } /** Find all possible plans at the next level, based on existing levels. */ private def searchLevel( existingLevels: Seq[JoinPlanMap], - conf: CatalystConf, + conf: SQLConf, conditions: Set[Expression], topOutput: AttributeSet): JoinPlanMap = { @@ -185,11 +184,14 @@ object JoinReorderDP extends PredicateHelper { // Should not join two overlapping item sets. if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) { val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) - // Check if it's the first plan for the item set, or it's a better plan than - // the existing one due to lower cost. - val existingPlan = nextLevel.get(joinPlan.itemIds) - if (existingPlan.isEmpty || joinPlan.cost.lessThan(existingPlan.get.cost)) { - nextLevel.update(joinPlan.itemIds, joinPlan) + if (joinPlan.isDefined) { + val newJoinPlan = joinPlan.get + // Check if it's the first plan for the item set, or it's a better plan than + // the existing one due to lower cost. + val existingPlan = nextLevel.get(newJoinPlan.itemIds) + if (existingPlan.isEmpty || newJoinPlan.betterThan(existingPlan.get, conf)) { + nextLevel.update(newJoinPlan.itemIds, newJoinPlan) + } } } } @@ -203,64 +205,46 @@ object JoinReorderDP extends PredicateHelper { private def buildJoin( oneJoinPlan: JoinPlan, otherJoinPlan: JoinPlan, - conf: CatalystConf, + conf: SQLConf, conditions: Set[Expression], - topOutput: AttributeSet): JoinPlan = { + topOutput: AttributeSet): Option[JoinPlan] = { val onePlan = oneJoinPlan.plan val otherPlan = otherJoinPlan.plan - // Now both onePlan and otherPlan become intermediate joins, so the cost of the - // new join should also include their own cardinalities and sizes. - val newCost = if (isCartesianProduct(onePlan) || isCartesianProduct(otherPlan)) { - // We consider cartesian product very expensive, thus set a very large cost for it. - // This enables to plan all the cartesian products at the end, because having a cartesian - // product as an intermediate join will significantly increase a plan's cost, making it - // impossible to be selected as the best plan for the items, unless there's no other choice. 
- Cost( - rows = BigInt(Long.MaxValue) * BigInt(Long.MaxValue), - size = BigInt(Long.MaxValue) * BigInt(Long.MaxValue)) - } else { - val onePlanStats = onePlan.stats(conf) - val otherPlanStats = otherPlan.stats(conf) - Cost( - rows = oneJoinPlan.cost.rows + onePlanStats.rowCount.get + - otherJoinPlan.cost.rows + otherPlanStats.rowCount.get, - size = oneJoinPlan.cost.size + onePlanStats.sizeInBytes + - otherJoinPlan.cost.size + otherPlanStats.sizeInBytes) - } - - // Put the deeper side on the left, tend to build a left-deep tree. - val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { - (onePlan, otherPlan) - } else { - (otherPlan, onePlan) - } val joinConds = conditions .filterNot(l => canEvaluate(l, onePlan)) .filterNot(r => canEvaluate(r, otherPlan)) .filter(e => e.references.subsetOf(onePlan.outputSet ++ otherPlan.outputSet)) - // We use inner join whether join condition is empty or not. Since cross join is - // equivalent to inner join without condition. - val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) - val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds - val remainingConds = conditions -- collectedJoinConds - val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput - val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) - val newPlan = - if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { - Project(neededFromNewJoin.toSeq, newJoin) + if (joinConds.isEmpty) { + // Cartesian product is very expensive, so we exclude them from candidate plans. + // This also significantly reduces the search space. + None + } else { + // Put the deeper side on the left, tend to build a left-deep tree. + val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { + (onePlan, otherPlan) } else { - newJoin + (otherPlan, onePlan) } + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds + val remainingConds = conditions -- collectedJoinConds + val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput + val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) + val newPlan = + if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { + Project(neededFromNewJoin.toSeq, newJoin) + } else { + newJoin + } - val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) - JoinPlan(itemIds, newPlan, collectedJoinConds, newCost) - } - - private def isCartesianProduct(plan: LogicalPlan): Boolean = plan match { - case Join(_, _, _, None) => true - case Project(_, Join(_, _, _, None)) => true - case _ => false + val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) + // Now the root node of onePlan/otherPlan becomes an intermediate join (if it's a non-leaf + // item), so the cost of the new join should also include its own cost. + val newPlanCost = oneJoinPlan.planCost + oneJoinPlan.rootCost(conf) + + otherJoinPlan.planCost + otherJoinPlan.rootCost(conf) + Some(JoinPlan(itemIds, newPlan, collectedJoinConds, newPlanCost)) + } } /** Map[set of item ids, join plan for these items] */ @@ -272,26 +256,39 @@ object JoinReorderDP extends PredicateHelper { * @param itemIds Set of item ids participating in this partial plan. * @param plan The plan tree with the lowest cost for these items found so far. * @param joinConds Join conditions included in the plan. 
- * @param cost The cost of this plan is the sum of costs of all intermediate joins. + * @param planCost The cost of this plan tree is the sum of costs of all intermediate joins. */ - case class JoinPlan(itemIds: Set[Int], plan: LogicalPlan, joinConds: Set[Expression], cost: Cost) -} + case class JoinPlan( + itemIds: Set[Int], + plan: LogicalPlan, + joinConds: Set[Expression], + planCost: Cost) { -/** This class defines the cost model. */ -case class Cost(rows: BigInt, size: BigInt) { - /** - * An empirical value for the weights of cardinality (number of rows) in the cost formula: - * cost = rows * weight + size * (1 - weight), usually cardinality is more important than size. - */ - val weight = 0.7 + /** Get the cost of the root node of this plan tree. */ + def rootCost(conf: SQLConf): Cost = { + if (itemIds.size > 1) { + val rootStats = plan.stats(conf) + Cost(rootStats.rowCount.get, rootStats.sizeInBytes) + } else { + // If the plan is a leaf item, it has zero cost. + Cost(0, 0) + } + } - def lessThan(other: Cost): Boolean = { - if (other.rows == 0 || other.size == 0) { - false - } else { - val relativeRows = BigDecimal(rows) / BigDecimal(other.rows) - val relativeSize = BigDecimal(size) / BigDecimal(other.size) - relativeRows * weight + relativeSize * (1 - weight) < 1 + def betterThan(other: JoinPlan, conf: SQLConf): Boolean = { + if (other.planCost.rows == 0 || other.planCost.size == 0) { + false + } else { + val relativeRows = BigDecimal(this.planCost.rows) / BigDecimal(other.planCost.rows) + val relativeSize = BigDecimal(this.planCost.size) / BigDecimal(other.planCost.size) + relativeRows * conf.joinReorderCardWeight + + relativeSize * (1 - conf.joinReorderCardWeight) < 1 + } } } } + +/** This class defines the cost model. */ +case class Cost(rows: BigInt, size: BigInt) { + def +(other: Cost): Cost = Cost(this.rows + other.rows, this.size + other.size) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a85f87aece45..d2ac4b88ee8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -710,6 +710,15 @@ object SQLConf { .intConf .createWithDefault(12) + val JOIN_REORDER_CARD_WEIGHT = + buildConf("spark.sql.cbo.joinReorder.card.weight") + .internal() + .doc("The weight of cardinality (number of rows) for plan cost comparison in join reorder: " + + "rows * weight + size * (1 - weight).") + .doubleConf + .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") + .createWithDefault(0.7) + val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. 
"GMT", "America/Los_Angeles", etc.""") @@ -967,6 +976,8 @@ class SQLConf extends Serializable with Logging { def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + def joinReorderCardWeight: Double = getConf(SQLConf.JOIN_REORDER_CARD_WEIGHT) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) def sortMergeJoinExecBufferSpillThreshold: Int = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 1b2f7a66b6a0..5607bcd16f3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -38,6 +38,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, + ReorderJoin, PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: @@ -58,6 +59,10 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4), attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, avgLen = 4, maxLen = 4) )) @@ -92,6 +97,13 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { size = Some(100 * (8 + 4)), attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) + // Table t5: small table with two columns + private val t5 = StatsTestPlan( + outputList = Seq("t5.k-1-5", "t5.v-1-5").map(nameToAttr), + rowCount = 20, + size = Some(20 * (8 + 4)), + attributeStats = AttributeMap(Seq("t5.k-1-5", "t5.v-1-5").map(nameToColInfo))) + test("reorder 3 tables") { val originalPlan = t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && @@ -110,13 +122,17 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { assertEqualPlans(originalPlan, bestPlan) } - test("reorder 3 tables - put cross join at the end") { + test("put unjoinable item at the end and reorder 3 joinable tables") { + // The ReorderJoin rule puts the unjoinable item at the end, and then CostBasedJoinReorder + // reorders other joinable items. 
val originalPlan = - t1.join(t2).join(t3).where(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")) + t1.join(t2).join(t4).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) val bestPlan = t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) - .join(t2, Inner, None) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4) assertEqualPlans(originalPlan, bestPlan) } @@ -136,6 +152,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { assertEqualPlans(originalPlan, bestPlan) } + test("reorder 3 tables - one of the leaf items is a project") { + val originalPlan = + t1.join(t5).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t5.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + // Items: t1, t3, project(t5.k-1-5, t5) + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t5.select(nameToAttr("t5.k-1-5")), Inner, + Some(nameToAttr("t1.k-1-2") === nameToAttr("t5.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + assertEqualPlans(originalPlan, bestPlan) + } + test("don't reorder if project contains non-attribute") { val originalPlan = t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) @@ -187,6 +220,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { case (j1: Join, j2: Join) => (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => + (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } case _ => plan1 == plan2 } From ccba622e35741d8344ec8d74b6750529b2c7219b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 18 Mar 2017 14:40:16 +0800 Subject: [PATCH 064/512] [SPARK-19896][SQL] Throw an exception if case classes have circular references in toDS ## What changes were proposed in this pull request? 
If case classes have circular references below, it throws StackOverflowError; ``` scala> :paste case class classA(i: Int, cls: classB) case class classB(cls: classA) scala> Seq(classA(0, null)).toDS() java.lang.StackOverflowError at scala.reflect.internal.Symbols$Symbol.info(Symbols.scala:1494) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.scala$reflect$runtime$SynchronizedSymbols$SynchronizedSymbol$$super$info(JavaMirrors.scala:66) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anonfun$info$1.apply(SynchronizedSymbols.scala:127) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$$anonfun$info$1.apply(SynchronizedSymbols.scala:127) at scala.reflect.runtime.Gil$class.gilSynchronized(Gil.scala:19) at scala.reflect.runtime.JavaUniverse.gilSynchronized(JavaUniverse.scala:16) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$class.gilSynchronizedIfNotThreadsafe(SynchronizedSymbols.scala:123) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.gilSynchronizedIfNotThreadsafe(JavaMirrors.scala:66) at scala.reflect.runtime.SynchronizedSymbols$SynchronizedSymbol$class.info(SynchronizedSymbols.scala:127) at scala.reflect.runtime.JavaMirrors$JavaMirror$$anon$1.info(JavaMirrors.scala:66) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:48) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:45) ``` This PR added code to throw UnsupportedOperationException in that case as follows; ``` scala> :paste case class A(cls: B) case class B(cls: A) scala> Seq(A(null)).toDS() java.lang.UnsupportedOperationException: cannot have circular references in class, but got the circular reference of class B at org.apache.spark.sql.catalyst.ScalaReflection$.org$apache$spark$sql$catalyst$ScalaReflection$$serializerFor(ScalaReflection.scala:627) at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:644) at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:632) at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241) at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241) at scala.collection.immutable.List.foreach(List.scala:381) at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241) ``` ## How was this patch tested? Added tests in `DatasetSuite`. Author: Takeshi Yamamuro Closes #17318 from maropu/SPARK-19896.
--- .../spark/sql/catalyst/ScalaReflection.scala | 20 ++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 24 +++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7f7dd51aa265..c4af284f73d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -470,14 +470,15 @@ object ScalaReflection extends ScalaReflection { private def serializerFor( inputObject: Expression, tpe: `Type`, - walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + walkedTypePath: Seq[String], + seenTypeSet: Set[`Type`] = Set.empty): Expression = ScalaReflectionLock.synchronized { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(serializerFor(_, elementType, newPath), input, dt) + MapObjects(serializerFor(_, elementType, newPath, seenTypeSet), input, dt) case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => @@ -511,7 +512,7 @@ object ScalaReflection extends ScalaReflection { val className = getClassNameFromType(optType) val newPath = s"""- option value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, newPath) + serializerFor(unwrapped, optType, newPath, seenTypeSet) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the @@ -534,9 +535,9 @@ object ScalaReflection extends ScalaReflection { ExternalMapToCatalyst( inputObject, dataTypeFor(keyType), - serializerFor(_, keyType, keyPath), + serializerFor(_, keyType, keyPath, seenTypeSet), dataTypeFor(valueType), - serializerFor(_, valueType, valuePath), + serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) case t if t <:< localTypeOf[String] => @@ -622,6 +623,11 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "serialize", udt, inputObject :: Nil) case t if definedByConstructorParams(t) => + if (seenTypeSet.contains(t)) { + throw new UnsupportedOperationException( + s"cannot have circular references in class, but got the circular reference of class $t") + } + val params = getConstructorParameters(t) val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => if (javaKeywords.contains(fieldName)) { @@ -634,7 +640,8 @@ object ScalaReflection extends ScalaReflection { returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil + expressions.Literal(fieldName) :: + serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil }) val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) @@ -643,7 +650,6 @@ object ScalaReflection extends ScalaReflection { throw new 
UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } - } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b37bf131e8dc..6417e7a8b603 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1136,6 +1136,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head == new java.sql.Timestamp(100000)) } + + test("SPARK-19896: cannot have circular references in case class") { + val errMsg1 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassA(null)).toDS + } + assert(errMsg1.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + val errMsg2 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassC(null)).toDS + } + assert(errMsg2.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + val errMsg3 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassD(null)).toDS + } + assert(errMsg3.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) @@ -1214,3 +1232,9 @@ object DatasetTransform { case class Route(src: String, dest: String, cost: Int) case class GroupedRoutes(src: String, dest: String, routes: Seq[Route]) + +case class CircularReferenceClassA(cls: CircularReferenceClassB) +case class CircularReferenceClassB(cls: CircularReferenceClassA) +case class CircularReferenceClassC(ar: Array[CircularReferenceClassC]) +case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE]) +case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD]) From 54e61df2634163382c7d01a2ad40ffb5e7270abc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 18 Mar 2017 18:01:24 +0100 Subject: [PATCH 065/512] [SPARK-16599][CORE] java.util.NoSuchElementException: None.get at at org.apache.spark.storage.BlockInfoManager.releaseAllLocksForTask ## What changes were proposed in this pull request? Avoid None.get exception in (rare?) case that no readLocks exist Note that while this would resolve the immediate cause of the exception, it's not clear it is the root problem. ## How was this patch tested? Existing tests Author: Sean Owen Closes #17290 from srowen/SPARK-16599.
--- .../scala/org/apache/spark/storage/BlockInfoManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index dd8f5bacb9f6..490d45d12b8e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag -import com.google.common.collect.ConcurrentHashMultiset +import com.google.common.collect.{ConcurrentHashMultiset, ImmutableMultiset} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -340,7 +340,7 @@ private[storage] class BlockInfoManager extends Logging { val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() val readLocks = synchronized { - readLocksByTask.remove(taskAttemptId).get + readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) } val writeLocks = synchronized { writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) From 5c165596dac136b9b3a88cfb3578b2423d227eb7 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sat, 18 Mar 2017 16:26:48 -0700 Subject: [PATCH 066/512] [SPARK-19654][SPARKR][SS] Structured Streaming API for R ## What changes were proposed in this pull request? Add "experimental" API for SS in R ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #16982 from felixcheung/rss. --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 13 ++ R/pkg/R/DataFrame.R | 104 ++++++++++- R/pkg/R/SQLContext.R | 50 +++++ R/pkg/R/generics.R | 41 +++- R/pkg/R/streaming.R | 208 +++++++++++++++++++++ R/pkg/R/utils.R | 11 +- R/pkg/inst/tests/testthat/test_streaming.R | 150 +++++++++++++++ 8 files changed, 573 insertions(+), 5 deletions(-) create mode 100644 R/pkg/R/streaming.R create mode 100644 R/pkg/inst/tests/testthat/test_streaming.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index cc471edc376b..1635f71489aa 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -54,5 +54,6 @@ Collate: 'types.R' 'utils.R' 'window.R' + 'streaming.R' RoxygenNote: 5.0.1 VignetteBuilder: knitr diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 871f8e41a0f2..78344ce9ff08 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -121,6 +121,7 @@ exportMethods("arrange", "insertInto", "intersect", "isLocal", + "isStreaming", "join", "limit", "merge", @@ -169,6 +170,7 @@ exportMethods("arrange", "write.json", "write.orc", "write.parquet", + "write.stream", "write.text", "write.ml") @@ -365,6 +367,7 @@ export("as.DataFrame", "read.json", "read.orc", "read.parquet", + "read.stream", "read.text", "spark.lapply", "spark.addFile", @@ -402,6 +405,16 @@ export("partitionBy", export("windowPartitionBy", "windowOrderBy") +exportClasses("StreamingQuery") + +export("awaitTermination", + "isActive", + "lastProgress", + "queryName", + "status", + "stopQuery") + + S3method(print, jobj) S3method(print, structField) S3method(print, structType) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 97e0c9edeab4..bc81633815c6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -133,9 +133,6 @@ setMethod("schema", #' #' Print the logical and physical Catalyst plans to the console for debugging. #' -#' @param x a SparkDataFrame. -#' @param extended Logical. If extended is FALSE, explain() only prints the physical plan. -#' @param ... 
further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases explain,SparkDataFrame-method #' @rdname explain @@ -3515,3 +3512,104 @@ setMethod("getNumPartitions", function(x) { callJMethod(callJMethod(x@sdf, "rdd"), "getNumPartitions") }) + +#' isStreaming +#' +#' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data +#' as it arrives. +#' +#' @param x A SparkDataFrame +#' @return TRUE if this SparkDataFrame is from a streaming source +#' @family SparkDataFrame functions +#' @aliases isStreaming,SparkDataFrame-method +#' @rdname isStreaming +#' @name isStreaming +#' @seealso \link{read.stream} \link{write.stream} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' isStreaming(df) +#' } +#' @note isStreaming since 2.2.0 +#' @note experimental +setMethod("isStreaming", + signature(x = "SparkDataFrame"), + function(x) { + callJMethod(x@sdf, "isStreaming") + }) + +#' Write the streaming SparkDataFrame to a data source. +#' +#' The data source is specified by the \code{source} and a set of options (...). +#' If \code{source} is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, \code{outputMode} specifies how data of a streaming SparkDataFrame is written to a +#' output data source. There are three modes: +#' \itemize{ +#' \item append: Only the new rows in the streaming SparkDataFrame will be written out. This +#' output mode can be only be used in queries that do not contain any aggregation. +#' \item complete: All the rows in the streaming SparkDataFrame will be written out every time +#' there are some updates. This output mode can only be used in queries that +#' contain aggregations. +#' \item update: Only the rows that were updated in the streaming SparkDataFrame will be written +#' out every time there are some updates. If the query doesn't contain aggregations, +#' it will be equivalent to \code{append} mode. +#' } +#' +#' @param df a streaming SparkDataFrame. +#' @param source a name for external data source. +#' @param outputMode one of 'append', 'complete', 'update'. +#' @param ... additional argument(s) passed to the method. +#' +#' @family SparkDataFrame functions +#' @seealso \link{read.stream} +#' @aliases write.stream,SparkDataFrame-method +#' @rdname write.stream +#' @name write.stream +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' isStreaming(df) +#' wordCounts <- count(group_by(df, "value")) +#' +#' # console +#' q <- write.stream(wordCounts, "console", outputMode = "complete") +#' # text stream +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' # memory stream +#' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") +#' head(sql("SELECT * from outs")) +#' queryName(q) +#' +#' stopQuery(q) +#' } +#' @note write.stream since 2.2.0 +#' @note experimental +setMethod("write.stream", + signature(df = "SparkDataFrame"), + function(df, source = NULL, outputMode = NULL, ...) { + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. 
It is the data source specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (!is.null(outputMode) && !is.character(outputMode)) { + stop("outputMode should be charactor or omitted.") + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + options <- varargsToStrEnv(...) + write <- handledCallJMethod(df@sdf, "writeStream") + write <- callJMethod(write, "format", source) + if (!is.null(outputMode)) { + write <- callJMethod(write, "outputMode", outputMode) + } + write <- callJMethod(write, "options", options) + ssq <- handledCallJMethod(write, "start") + streamingQuery(ssq) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 8354f705f6de..b75fb0159d50 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -937,3 +937,53 @@ read.jdbc <- function(url, tableName, } dataFrame(sdf) } + +#' Load a streaming SparkDataFrame +#' +#' Returns the dataset in a data source as a SparkDataFrame +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param source The name of external data source +#' @param schema The data schema defined in structType, this is required for file-based streaming +#' data source +#' @param ... additional external data source specific named options, for instance \code{path} for +#' file-based streaming data source +#' @return SparkDataFrame +#' @rdname read.stream +#' @name read.stream +#' @seealso \link{write.stream} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' } +#' @name read.stream +#' @note read.stream since 2.2.0 +#' @note experimental +read.stream <- function(source = NULL, schema = NULL, ...) { + sparkSession <- getSparkSession() + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the data source specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + options <- varargsToStrEnv(...) + read <- callJMethod(sparkSession, "readStream") + read <- callJMethod(read, "format", source) + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + read <- callJMethod(read, "schema", schema$jobj) + } + read <- callJMethod(read, "options", options) + sdf <- handledCallJMethod(read, "load") + dataFrame(callJMethod(sdf, "toDF")) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 45bc12746511..029771289fd5 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -539,6 +539,9 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain #' @export +#' @param x a SparkDataFrame or a StreamingQuery. +#' @param extended Logical. If extended is FALSE, prints only the physical plan. +#' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) 
{ standardGeneric("explain") }) #' @rdname except @@ -577,6 +580,10 @@ setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) +#' @rdname isStreaming +#' @export +setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) + #' @rdname limit #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) @@ -682,6 +689,12 @@ setGeneric("write.parquet", function(x, path, ...) { #' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) +#' @rdname write.stream +#' @export +setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { + standardGeneric("write.stream") +}) + #' @rdname write.text #' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) @@ -1428,10 +1441,36 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) - #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) + + +###################### Streaming Methods ########################## + +#' @rdname awaitTermination +#' @export +setGeneric("awaitTermination", function(x, timeout) { standardGeneric("awaitTermination") }) + +#' @rdname isActive +#' @export +setGeneric("isActive", function(x) { standardGeneric("isActive") }) + +#' @rdname lastProgress +#' @export +setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) + +#' @rdname queryName +#' @export +setGeneric("queryName", function(x) { standardGeneric("queryName") }) + +#' @rdname status +#' @export +setGeneric("status", function(x) { standardGeneric("status") }) + +#' @rdname stopQuery +#' @export +setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R new file mode 100644 index 000000000000..e353d2dd07c3 --- /dev/null +++ b/R/pkg/R/streaming.R @@ -0,0 +1,208 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# streaming.R - Structured Streaming / StreamingQuery class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R +NULL + +#' S4 class that represents a StreamingQuery +#' +#' StreamingQuery can be created by using read.stream() and write.stream() +#' +#' @rdname StreamingQuery +#' @seealso \link{read.stream} +#' +#' @param ssq A Java object reference to the backing Scala StreamingQuery +#' @export +#' @note StreamingQuery since 2.2.0 +#' @note experimental +setClass("StreamingQuery", + slots = list(ssq = "jobj")) + +setMethod("initialize", "StreamingQuery", function(.Object, ssq) { + .Object@ssq <- ssq + .Object +}) + +streamingQuery <- function(ssq) { + stopifnot(class(ssq) == "jobj") + new("StreamingQuery", ssq) +} + +#' @rdname show +#' @export +#' @note show(StreamingQuery) since 2.2.0 +setMethod("show", "StreamingQuery", + function(object) { + name <- callJMethod(object@ssq, "name") + if (!is.null(name)) { + cat(paste0("StreamingQuery '", name, "'\n")) + } else { + cat("StreamingQuery", "\n") + } + }) + +#' queryName +#' +#' Returns the user-specified name of the query. This is specified in +#' \code{write.stream(df, queryName = "query")}. This name, if set, must be unique across all active +#' queries. +#' +#' @param x a StreamingQuery. +#' @return The name of the query, or NULL if not specified. +#' @rdname queryName +#' @name queryName +#' @aliases queryName,StreamingQuery-method +#' @family StreamingQuery methods +#' @seealso \link{write.stream} +#' @export +#' @examples +#' \dontrun{ queryName(sq) } +#' @note queryName(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("queryName", + signature(x = "StreamingQuery"), + function(x) { + callJMethod(x@ssq, "name") + }) + +#' @rdname explain +#' @name explain +#' @aliases explain,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ explain(sq) } +#' @note explain(StreamingQuery) since 2.2.0 +setMethod("explain", + signature(x = "StreamingQuery"), + function(x, extended = FALSE) { + cat(callJMethod(x@ssq, "explainInternal", extended), "\n") + }) + +#' lastProgress +#' +#' Prints the most recent progess update of this streaming query in JSON format. +#' +#' @param x a StreamingQuery. +#' @rdname lastProgress +#' @name lastProgress +#' @aliases lastProgress,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ lastProgress(sq) } +#' @note lastProgress(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("lastProgress", + signature(x = "StreamingQuery"), + function(x) { + p <- callJMethod(x@ssq, "lastProgress") + if (is.null(p)) { + cat("Streaming query has no progress") + } else { + cat(callJMethod(p, "toString"), "\n") + } + }) + +#' status +#' +#' Prints the current status of the query in JSON format. +#' +#' @param x a StreamingQuery. +#' @rdname status +#' @name status +#' @aliases status,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ status(sq) } +#' @note status(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("status", + signature(x = "StreamingQuery"), + function(x) { + cat(callJMethod(callJMethod(x@ssq, "status"), "toString"), "\n") + }) + +#' isActive +#' +#' Returns TRUE if this query is actively running. +#' +#' @param x a StreamingQuery. +#' @return TRUE if query is actively running, FALSE if stopped. 
+#' @rdname isActive +#' @name isActive +#' @aliases isActive,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ isActive(sq) } +#' @note isActive(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("isActive", + signature(x = "StreamingQuery"), + function(x) { + callJMethod(x@ssq, "isActive") + }) + +#' awaitTermination +#' +#' Waits for the termination of the query, either by \code{stopQuery} or by an error. +#' +#' If the query has terminated, then all subsequent calls to this method will return TRUE +#' immediately. +#' +#' @param x a StreamingQuery. +#' @param timeout time to wait in milliseconds +#' @return TRUE if query has terminated within the timeout period. +#' @rdname awaitTermination +#' @name awaitTermination +#' @aliases awaitTermination,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ awaitTermination(sq, 10000) } +#' @note awaitTermination(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("awaitTermination", + signature(x = "StreamingQuery"), + function(x, timeout) { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + }) + +#' stopQuery +#' +#' Stops the execution of this query if it is running. This method blocks until the execution is +#' stopped. +#' +#' @param x a StreamingQuery. +#' @rdname stopQuery +#' @name stopQuery +#' @aliases stopQuery,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ stopQuery(sq) } +#' @note stopQuery(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("stopQuery", + signature(x = "StreamingQuery"), + function(x) { + invisible(callJMethod(x@ssq, "stop")) + }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 1f7848f2b413..810de9917e0b 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -823,7 +823,16 @@ captureJVMException <- function(e, method) { stacktrace <- rawmsg } - if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { + # StreamingQueryException could wrap an IllegalArgumentException, so look for that first + if (any(grep("org.apache.spark.sql.streaming.StreamingQueryException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.streaming.StreamingQueryException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "streaming query error - ", first), call. = FALSE) + } else if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]] # Extract "Error in ..." message. rmsg <- msg[1] diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R new file mode 100644 index 000000000000..03b1bd3dc1f4 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -0,0 +1,150 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("Structured Streaming") + +# Tests for Structured Streaming functions in SparkR + +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +jsonSubDir <- file.path("sparkr-test", "json", "") +if (.Platform$OS.type == "windows") { + # file.path removes the empty separator on Windows, adds it back + jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) +} +jsonDir <- file.path(tempdir(), jsonSubDir) +dir.create(jsonDir, recursive = TRUE) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern = jsonSubDir, fileext = ".tmp") +writeLines(mockLines, jsonPath) + +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}") +jsonPathNa <- tempfile(pattern = jsonSubDir, fileext = ".tmp") + +schema <- structType(structField("name", "string"), + structField("age", "integer"), + structField("count", "double")) + +test_that("read.stream, write.stream, awaitTermination, stopQuery", { + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) + + writeLines(mockLinesNa, jsonPathNa) + awaitTermination(q, 5 * 1000) + expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) + + stopQuery(q) + expect_true(awaitTermination(q, 1)) +}) + +test_that("print from explain, lastProgress, status, isActive", { + df <- read.stream("json", path = jsonDir, schema = schema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") + + awaitTermination(q, 5 * 1000) + + expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") + expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) + expect_true(any(grepl("\"isTriggerActive\" : ", capture.output(status(q))))) + + expect_equal(queryName(q), "people2") + expect_true(isActive(q)) + + stopQuery(q) +}) + +test_that("Stream other format", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + df <- read.df(jsonPath, "json", schema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = schema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + + expect_equal(queryName(q), "people3") + expect_true(any(grepl("\"description\" : \"FileStreamSource[[:print:]]+parquet", + capture.output(lastProgress(q))))) + expect_true(isActive(q)) + + stopQuery(q) + expect_true(awaitTermination(q, 1)) + 
expect_false(isActive(q)) + + unlink(parquetPath) +}) + +test_that("Non-streaming DataFrame", { + c <- as.DataFrame(cars) + expect_false(isStreaming(c)) + + expect_error(write.stream(c, "memory", queryName = "people", outputMode = "complete"), + paste0(".*(writeStream : analysis error - 'writeStream' can be called only on ", + "streaming Dataset/DataFrame).*")) +}) + +test_that("Unsupported operation", { + # memory sink without aggregation + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) + expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), + paste0(".*(start : analysis error - Complete output mode not supported when there ", + "are no streaming aggregations on streaming DataFrames/Datasets).*")) +}) + +test_that("Terminated by error", { + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) + counts <- count(group_by(df, "name")) + # This would not fail before returning with a StreamingQuery, + # but could dump error log at just about the same time + expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), + NA) + + expect_error(awaitTermination(q, 1), + paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", + " 'maxFilesPerTrigger', must be a positive integer).*")) + + expect_true(any(grepl("\"message\" : \"Terminated with exception: Invalid value", + capture.output(status(q))))) + expect_true(any(grepl("Streaming query has no progress", capture.output(lastProgress(q))))) + expect_equal(queryName(q), "people4") + expect_false(isActive(q)) + + stopQuery(q) +}) + +unlink(jsonPath) +unlink(jsonPathNa) + +sparkR.session.stop() From 60262bc951864a7a3874ab3570b723198e99d613 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 19 Mar 2017 10:30:34 -0700 Subject: [PATCH 067/512] [MINOR][R] Reorder `Collate` fields in DESCRIPTION file ## What changes were proposed in this pull request? It seems cran check scripts corrects `R/pkg/DESCRIPTION` and follows the order in `Collate` fields. This PR proposes to fix this so that running this script does not show up a diff in this file. ## How was this patch tested? Manually via `./R/check-cran.sh`. Author: hyukjinkwon Closes #17349 from HyukjinKwon/minor-cran. --- R/pkg/DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 1635f71489aa..2ea90f7d3666 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -51,9 +51,9 @@ Collate: 'serialize.R' 'sparkR.R' 'stats.R' + 'streaming.R' 'types.R' 'utils.R' 'window.R' - 'streaming.R' RoxygenNote: 5.0.1 VignetteBuilder: knitr From 422aa67d1bb84f913b06e6d94615adb6557e2870 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 19 Mar 2017 10:37:15 -0700 Subject: [PATCH 068/512] [SPARK-18817][SPARKR][SQL] change derby log output to temp dir ## What changes were proposed in this pull request? Passes R `tempdir()` (this is the R session temp dir, shared with other temp files/dirs) to JVM, set System.Property for derby home dir to move derby.log ## How was this patch tested? Manually, unit tests With this, these are relocated to under /tmp ``` # ls /tmp/RtmpG2M0cB/ derby.log ``` And they are removed automatically when the R session is ended. Author: Felix Cheung Closes #16330 from felixcheung/rderby. 
--- R/pkg/R/sparkR.R | 15 +++++++- R/pkg/inst/tests/testthat/test_sparkSQL.R | 34 +++++++++++++++++++ R/pkg/tests/run-all.R | 6 ++++ .../scala/org/apache/spark/api/r/RRDD.scala | 9 +++++ 4 files changed, 63 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 61773ed3ee8c..d0a12b7ecec6 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -322,10 +322,19 @@ sparkRHive.init <- function(jsc = NULL) { #' SparkSession or initializes a new SparkSession. #' Additional Spark properties can be set in \code{...}, and these named parameters take priority #' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. -#' When called in an interactive session, this checks for the Spark installation, and, if not +#' +#' When called in an interactive session, this method checks for the Spark installation, and, if not #' found, it will be downloaded and cached automatically. Alternatively, \code{install.spark} can #' be called manually. #' +#' A default warehouse is created automatically in the current directory when a managed table is +#' created via \code{sql} statement \code{CREATE TABLE}, for example. To change the location of the +#' warehouse, set the named parameter \code{spark.sql.warehouse.dir} to the SparkSession. Along with +#' the warehouse, an accompanied metastore may also be automatically created in the current +#' directory when a new SparkSession is initialized with \code{enableHiveSupport} set to +#' \code{TRUE}, which is the default. For more details, refer to Hive configuration at +#' \url{http://spark.apache.org/docs/latest/sql-programming-guide.html#hive-tables}. +#' #' For details on how to initialize and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. #' @@ -381,6 +390,10 @@ sparkR.session <- function( deployMode <- sparkConfigMap[["spark.submit.deployMode"]] } + if (!exists("spark.r.sql.derby.temp.dir", envir = sparkConfigMap)) { + sparkConfigMap[["spark.r.sql.derby.temp.dir"]] <- tempdir() + } + if (!exists(".sparkRjsc", envir = .sparkREnv)) { retHome <- sparkCheckInstall(sparkHome, master, deployMode) if (!is.null(retHome)) sparkHome <- retHome diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index f7081cb1d4e5..32856b399cdd 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -60,6 +60,7 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR +filesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) @@ -2909,6 +2910,39 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) }) +compare_list <- function(list1, list2) { + # get testthat to show the diff by first making the 2 lists equal in length + expect_equal(length(list1), length(list2)) + l <- max(length(list1), length(list2)) + length(list1) <- l + length(list2) <- l + expect_equal(sort(list1, na.last = TRUE), sort(list2, na.last = TRUE)) +} + +# This should always be the **very last test** in this test file. +test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + # Check that it is not creating any extra file. + # Does not check the tempdir which would be cleaned up after. 
+ filesAfter <- list.files(path = sparkRDir, all.files = TRUE) + + expect_true(length(sparkRFilesBefore) > 0) + # first, ensure derby.log is not there + expect_false("derby.log" %in% filesAfter) + # second, ensure only spark-warehouse is created when calling SparkSession, enableHiveSupport = F + # note: currently all other test files have enableHiveSupport = F, so we capture the list of files + # before creating a SparkSession with enableHiveSupport = T at the top of this test file + # (filesBefore). The test here is to compare that (filesBefore) against the list of files before + # any test is run in run-all.R (sparkRFilesBefore). + # sparkRWhitelistSQLDirs is also defined in run-all.R, and should contain only 2 whitelisted dirs, + # here allow the first value, spark-warehouse, in the diff, everything else should be exactly the + # same as before any test is run. + compare_list(sparkRFilesBefore, setdiff(filesBefore, sparkRWhitelistSQLDirs[[1]])) + # third, ensure only spark-warehouse and metastore_db are created when enableHiveSupport = T + # note: as the note above, after running all tests in this file while enableHiveSupport = T, we + # check the list of files again. This time we allow both whitelisted dirs to be in the diff. + compare_list(sparkRFilesBefore, setdiff(filesAfter, sparkRWhitelistSQLDirs)) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index ab8d1ca01994..cefaadda6e21 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -22,6 +22,12 @@ library(SparkR) options("warn" = 2) # Setup global test environment +sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) +sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") +invisible(lapply(sparkRWhitelistSQLDirs, + function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) + install.spark() test_package("SparkR") diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index a1a5eb8cf55e..72ae0340aa3d 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.r +import java.io.File import java.util.{Map => JMap} import scala.collection.JavaConverters._ @@ -127,6 +128,14 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.toString, value.toString) } + if (sparkEnvirMap.containsKey("spark.r.sql.derby.temp.dir") && + System.getProperty("derby.stream.error.file") == null) { + // This must be set before SparkContext is instantiated. + System.setProperty("derby.stream.error.file", + Seq(sparkEnvirMap.get("spark.r.sql.derby.temp.dir").toString, "derby.log") + .mkString(File.separator)) + } + val jsc = new JavaSparkContext(sparkConf) jars.foreach { jar => jsc.addJar(jar) From 0ee9fbf51ac863e015d57ae7824a39bd3b36141a Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sun, 19 Mar 2017 13:52:22 -0700 Subject: [PATCH 069/512] [SPARK-19990][TEST] Use the database after Hive's current Database is dropped ### What changes were proposed in this pull request? This PR is to fix the following test failure in maven and the PR https://github.com/apache/spark/pull/15363. 
> org.apache.spark.sql.hive.orc.OrcSourceSuite SPARK-19459/SPARK-18220: read char/varchar column written by Hive The[ test history](https://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.sql.hive.orc.OrcSourceSuite&test_name=SPARK-19459%2FSPARK-18220%3A+read+char%2Fvarchar+column+written+by+Hive) shows all the maven builds failed this test case with the same error message. ``` FAILED: SemanticException [Error 10072]: Database does not exist: db2 org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10072]: Database does not exist: db2 at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:637) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:621) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:288) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:229) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:228) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:271) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:621) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:611) at org.apache.spark.sql.hive.orc.OrcSuite$$anonfun$7.apply$mcV$sp(OrcSourceSuite.scala:160) at org.apache.spark.sql.hive.orc.OrcSuite$$anonfun$7.apply(OrcSourceSuite.scala:155) at org.apache.spark.sql.hive.orc.OrcSuite$$anonfun$7.apply(OrcSourceSuite.scala:155) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) ``` ### How was this patch tested? N/A Author: Xiao Li Closes #17344 from gatorsmile/testtest. --- .../spark/sql/hive/orc/OrcSourceSuite.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 11dda5425cf9..6bfb88c0c1af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -157,19 +157,21 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA val location = Utils.createTempDir() val uri = location.toURI try { + hiveClient.runSqlHive("USE default") hiveClient.runSqlHive( """ - |CREATE EXTERNAL TABLE hive_orc( - | a STRING, - | b CHAR(10), - | c VARCHAR(10), - | d ARRAY) - |STORED AS orc""".stripMargin) + |CREATE EXTERNAL TABLE hive_orc( + | a STRING, + | b CHAR(10), + | c VARCHAR(10), + | d ARRAY) + |STORED AS orc""".stripMargin) // Hive throws an exception if I assign the location in the create table statement. 
hiveClient.runSqlHive( s"ALTER TABLE hive_orc SET LOCATION '$uri'") hiveClient.runSqlHive( - """INSERT INTO TABLE hive_orc + """ + |INSERT INTO TABLE hive_orc |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) |FROM (SELECT 1) t""".stripMargin) From 990af630d0d569880edd9c7ce9932e10037a28ab Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 19 Mar 2017 14:07:49 -0700 Subject: [PATCH 070/512] [SPARK-19067][SS] Processing-time-based timeout in MapGroupsWithState ## What changes were proposed in this pull request? When a key does not get any new data in `mapGroupsWithState`, the mapping function is never called on it. So we need a timeout feature that calls the function again in such cases, so that the user can decide whether to continue waiting or to clean up (remove state, save data externally, etc.). Timeouts can be based on either processing time or event time. This JIRA is for processing time, but it defines the high-level API design for both. The usage would look like this. ``` def stateFunction(key: K, value: Iterator[V], state: KeyedState[S]): U = { ... state.setTimeoutDuration(10000) ... } dataset // type is Dataset[T] .groupByKey[K](keyingFunc) // generates KeyValueGroupedDataset[K, T] .mapGroupsWithState[S, U]( func = stateFunction, timeout = KeyedStateTimeout.withProcessingTime) // returns Dataset[U] ``` Note the following design aspects. - The timeout type is provided in mapGroupsWithState as a parameter global to all the keys. This is so that the planner knows it at planning time and can accordingly optimize the execution based on whether to save extra info (e.g. timeout durations or timestamps) in the state or not. - The exact timeout duration is provided inside the function call so that it can be customized on a per-key basis. - When the timeout occurs for a key, the function is called with no values, and KeyedState.isTimingOut() set to true. - The timeout is reset for a key every time the function is called on that key, that is, when the key has new data or the key has timed out. So the user has to set the timeout duration every time the function is called, otherwise no timeout will be set. Guarantees provided on the timeout of a key, when the timeout duration is D ms: - Timeout will never be called before real clock time has advanced by D ms - Timeout will be called eventually when there is a trigger with any data in it (i.e. after D ms). So there is no strict upper bound on when the timeout would occur. For example, if there is no data in the stream (for any key) for a while, then the timeout will not be hit. Implementation details: - Added new param to `mapGroupsWithState` for timeout - Added new method to `StateStore` to filter data based on timeout timestamp - Changed the internal map type of `HDFSBackedStateStore` from Java's `HashMap` to `ConcurrentHashMap` as the latter allows weakly-consistent fail-safe iterators on the map data. See comments in code for more details. - Refactored logic of `MapGroupsWithStateExec` to - Save timeout info to state store for each key that has data. - Then, filter states that should be timed out based on the current batch processing timestamp. - Moved KeyedState from `o.a.s.sql` to `o.a.s.sql.streaming`. I remember this was feedback in the MapGroupsWithState PR that I had forgotten to address. ## How was this patch tested? New unit tests in - MapGroupsWithStateSuite for timeouts. - StateStoreSuite for new APIs in StateStore. Author: Tathagata Das Closes #17179 from tdas/mapgroupwithstate-timeout.
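To make the timeout semantics above concrete, here is an illustrative sketch of a mapping function that expires idle keys. The `SessionInfo`/`SessionUpdate` case classes and the input types are made up for this example; the `KeyedState` calls (`setTimeoutDuration`, `isTimingOut`) follow the API as described in the commit message, and the curried `mapGroupsWithState(timeoutConf)(func)` overload and `KeyedStateTimeout.ProcessingTimeTimeout()` come from the diff below.

```scala
import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout}

object SessionExample {
  // Hypothetical state and output types, only for this example.
  case class SessionInfo(numEvents: Long)
  case class SessionUpdate(key: String, numEvents: Long, expired: Boolean)

  def trackSessions(
      key: String,
      events: Iterator[String],
      state: KeyedState[SessionInfo]): SessionUpdate = {
    if (state.isTimingOut) {
      // No data arrived for this key within the timeout: emit a final update and drop the state.
      val finalCount = state.getOption.map(_.numEvents).getOrElse(0L)
      state.remove()
      SessionUpdate(key, finalCount, expired = true)
    } else {
      // New data for this key: update the state and re-arm the timeout, since the
      // duration has to be set on every invocation (see the guarantees above).
      val updated = SessionInfo(state.getOption.map(_.numEvents).getOrElse(0L) + events.size)
      state.update(updated)
      state.setTimeoutDuration(10 * 1000) // 10 seconds of inactivity
      SessionUpdate(key, updated.numEvents, expired = false)
    }
  }

  // Usage (the dataset and grouping key are placeholders; encoders come from spark.implicits._):
  // events.groupByKey(identity)
  //   .mapGroupsWithState(KeyedStateTimeout.ProcessingTimeTimeout())(trackSessions)
}
```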
--- .../sql/streaming/KeyedStateTimeout.java | 42 ++ .../expressions/objects/objects.scala | 2 +- .../sql/catalyst/plans/logical/object.scala | 30 +- .../streaming/JavaKeyedStateTimeoutSuite.java | 29 + .../analysis/UnsupportedOperationsSuite.scala | 80 +-- .../FlatMapGroupsWithStateFunction.java | 2 +- .../function/MapGroupsWithStateFunction.java | 2 +- .../spark/sql/KeyValueGroupedDataset.scala | 137 +++-- .../org/apache/spark/sql/KeyedState.scala | 140 ----- .../spark/sql/execution/SparkStrategies.scala | 20 +- .../sql/execution/command/commands.scala | 5 +- .../FlatMapGroupsWithStateExec.scala | 258 +++++++++ .../streaming/IncrementalExecution.scala | 16 +- .../execution/streaming/KeyedStateImpl.scala | 104 +++- .../execution/streaming/StreamExecution.scala | 2 +- .../state/HDFSBackedStateStoreProvider.scala | 19 +- .../streaming/state/StateStore.scala | 9 + .../streaming/statefulOperators.scala | 97 +--- .../spark/sql/streaming/KeyedState.scala | 214 +++++++ .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../streaming/state/StateStoreSuite.scala | 24 + .../FlatMapGroupsWithStateSuite.scala | 546 ++++++++++++++++-- 22 files changed, 1353 insertions(+), 429 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java new file mode 100644 index 000000000000..cf112f2e02a9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout; +import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; + +/** + * Represents the type of timeouts possible for the Dataset operations + * `mapGroupsWithState` and `flatMapGroupsWithState`. See documentation on + * `KeyedState` for more details. 
+ * + * @since 2.2.0 + */ +@Experimental +@InterfaceStability.Evolving +public class KeyedStateTimeout { + + /** Timeout based on processing time. */ + public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } + + /** No timeout */ + public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 36bf3017d4cd..771ac28e5107 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -951,7 +951,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def eval(input: InternalRow): Any = { val result = child.eval(input) if (result == null) { - throw new RuntimeException(errMsg); + throw new RuntimeException(errMsg) } result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 7f4462e58360..d1f95faf2db0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode } import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -353,6 +353,10 @@ case class MapGroups( /** Internal class representing State */ trait LogicalKeyedState[S] +/** Possible types of timeouts used in FlatMapGroupsWithState */ +case object NoTimeout extends KeyedStateTimeout +case object ProcessingTimeTimeout extends KeyedStateTimeout + /** Factory for constructing new `MapGroupsWithState` nodes. */ object FlatMapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( @@ -361,7 +365,10 @@ object FlatMapGroupsWithState { dataAttributes: Seq[Attribute], outputMode: OutputMode, isMapGroupsWithState: Boolean, + timeout: KeyedStateTimeout, child: LogicalPlan): LogicalPlan = { + val encoder = encoderFor[S] + val mapped = new FlatMapGroupsWithState( func, UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), @@ -369,11 +376,11 @@ object FlatMapGroupsWithState { groupingAttributes, dataAttributes, CatalystSerde.generateObjAttr[U], - encoderFor[S].resolveAndBind().deserializer, - encoderFor[S].namedExpressions, + encoder.asInstanceOf[ExpressionEncoder[Any]], outputMode, - child, - isMapGroupsWithState) + isMapGroupsWithState, + timeout, + child) CatalystSerde.serialize[U](mapped) } } @@ -384,15 +391,16 @@ object FlatMapGroupsWithState { * Func is invoked with an object representation of the grouping key an iterator containing the * object representation of all the rows with that key. * + * @param func function called on each group * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. 
* @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param outputObjAttr used to define the output object - * @param stateDeserializer used to deserialize state before calling `func` - * @param stateSerializer used to serialize updated state after calling `func` + * @param stateEncoder used to serialize/deserialize state before calling `func` * @param outputMode the output mode of `func` * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method + * @param timeout used to timeout groups that have not received data in a while */ case class FlatMapGroupsWithState( func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], @@ -401,11 +409,11 @@ case class FlatMapGroupsWithState( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - stateDeserializer: Expression, - stateSerializer: Seq[NamedExpression], + stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, - child: LogicalPlan, - isMapGroupsWithState: Boolean = false) extends UnaryNode with ObjectProducer { + isMapGroupsWithState: Boolean = false, + timeout: KeyedStateTimeout, + child: LogicalPlan) extends UnaryNode with ObjectProducer { if (isMapGroupsWithState) { assert(outputMode == OutputMode.Update) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java new file mode 100644 index 000000000000..02c94b0b3244 --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; +import org.junit.Test; + +public class JavaKeyedStateTimeoutSuite { + + @Test + public void testTimeouts() { + assert(KeyedStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$); + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 200c39f43a6b..08216e266040 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -144,14 +144,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertSupportedInBatchPlan( s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation)) + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, + batchRelation)) assertSupportedInBatchPlan( s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation))) + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, + null, batchRelation))) } // FlatMapGroupsWithState(Update) in streaming without aggregation @@ -159,14 +161,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in update mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), outputMode = Update) assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), outputMode = Append, expectedMsgs = Seq("flatMapGroupsWithState in update mode", "Append")) @@ -174,7 +178,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + "on streaming relation without aggregation in complete mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), outputMode = Complete, // Disallowed by the aggregation check but let's still keep this test in case it's broken in // future. 
@@ -186,7 +191,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " + s"with aggregation in $outputMode mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, expectedMsgs = Seq("flatMapGroupsWithState in update mode", "with aggregation")) @@ -197,14 +202,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + "on streaming relation without aggregation in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), outputMode = Append) assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + "on streaming relation without aggregation in update mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), outputMode = Update, expectedMsgs = Seq("flatMapGroupsWithState in append mode", "update")) @@ -217,7 +224,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { Seq(attributeWithWatermark), aggExprs("c"), FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation)), outputMode = outputMode) } @@ -225,7 +233,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + s"on streaming relation after aggregation in $outputMode mode", - FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation")) @@ -235,7 +244,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "flatMapGroupsWithState - " + "flatMapGroupsWithState(Update) on streaming relation in complete mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), outputMode = Complete, // Disallowed by the aggregation check but let's still keep this test in case it's broken in // future. 
@@ -248,7 +258,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " + s"streaming relation in $outputMode output mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation), + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, + null, batchRelation), outputMode = outputMode ) } @@ -258,19 +269,20 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertSupportedInStreamingPlan( "flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " + "in append mode", - FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, - FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, streamRelation)), outputMode = Append) assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" + " are not in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)), + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation)), outputMode = Append, expectedMsgs = Seq("multiple flatMapGroupsWithState", "append")) @@ -279,8 +291,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "mapGroupsWithState - mapGroupsWithState " + "on streaming relation without aggregation in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, - isMapGroupsWithState = true), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation), outputMode = Append, // Disallowed by the aggregation check but let's still keep this test in case it's broken in // future. @@ -290,8 +302,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "mapGroupsWithState - mapGroupsWithState " + "on streaming relation without aggregation in complete mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, - isMapGroupsWithState = true), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation), outputMode = Complete, // Disallowed by the aggregation check but let's still keep this test in case it's broken in // future. 
@@ -301,10 +313,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertNotSupportedInStreamingPlan( "mapGroupsWithState - mapGroupsWithState on streaming relation " + s"with aggregation in $outputMode mode", - FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, - Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation), - isMapGroupsWithState = true), + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, + isMapGroupsWithState = true, null, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), outputMode = outputMode, expectedMsgs = Seq("mapGroupsWithState", "with aggregation")) } @@ -314,11 +325,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " + "in append mode", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, - isMapGroupsWithState = true), - isMapGroupsWithState = true), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation)), outputMode = Append, expectedMsgs = Seq("multiple mapGroupsWithStates")) @@ -327,11 +337,11 @@ class UnsupportedOperationsSuite extends SparkFunSuite { "mapGroupsWithState - " + "mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation", FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, FlatMapGroupsWithState( - null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation, - isMapGroupsWithState = false), - isMapGroupsWithState = true), + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation) + ), outputMode = Append, expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates")) diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index d44af7ef4815..29af78c4f6a8 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.KeyedState; +import org.apache.spark.sql.streaming.KeyedState; /** * ::Experimental:: diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index 75986d170620..70f3f01a8e9d 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.KeyedState; +import org.apache.spark.sql.streaming.KeyedState; /** * ::Experimental:: diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index ab956ffd642e..96437f868a6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout, OutputMode} /** * :: Experimental :: @@ -228,13 +228,14 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.KeyedState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving @@ -249,42 +250,49 @@ class KeyValueGroupedDataset[K, V] private[sql]( dataAttributes, OutputMode.Update, isMapGroupsWithState = true, + KeyedStateTimeout.NoTimeout, child = logicalPlan)) } /** * ::Experimental:: - * (Java-specific) + * (Scala-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group * state. The result Dataset will represent the objects returned by the function. * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.KeyedState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. + * @param func Function to be called on every group. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
- * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - mapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) - )(stateEncoder, outputEncoder) + def mapGroupsWithState[S: Encoder, U: Encoder]( + timeoutConf: KeyedStateTimeout)( + func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)) + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + OutputMode.Update, + isMapGroupsWithState = true, + timeoutConf, + child = logicalPlan)) } /** * ::Experimental:: - * (Scala-specific) + * (Java-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group * state. The result Dataset will represent the objects returned by the function. * For a static batch Dataset, the function will be invoked once per group. For a streaming @@ -294,33 +302,27 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving - def flatMapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = { - if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { - throw new IllegalArgumentException("The output mode of function should be append or update") - } - Dataset[U]( - sparkSession, - FlatMapGroupsWithState[K, V, S, U]( - func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], - groupingAttributes, - dataAttributes, - outputMode, - isMapGroupsWithState = false, - child = logicalPlan)) + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + mapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) + )(stateEncoder, outputEncoder) } /** * ::Experimental:: - * (Scala-specific) + * (Java-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group * state. The result Dataset will represent the objects returned by the function. * For a static batch Dataset, the function will be invoked once per group. For a streaming @@ -330,22 +332,29 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. + * @param func Function to be called on every group. 
+ * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving - def flatMapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = { - flatMapGroupsWithState(func, InternalOutputModes(outputMode)) + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: KeyedStateTimeout): Dataset[U] = { + mapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) + )(stateEncoder, outputEncoder) } /** * ::Experimental:: - * (Java-specific) + * (Scala-specific) * Applies the given function to each group of data, while maintaining a user-defined per-group * state. The result Dataset will represent the objects returned by the function. * For a static batch Dataset, the function will be invoked once per group. For a streaming @@ -355,25 +364,32 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param outputMode The output mode of the function. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving - def flatMapGroupsWithState[S, U]( - func: FlatMapGroupsWithStateFunction[K, V, S, U], + def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala, - outputMode - )(stateEncoder, outputEncoder) + timeoutConf: KeyedStateTimeout)( + func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + outputMode, + isMapGroupsWithState = false, + timeoutConf, + child = logicalPlan)) } /** @@ -392,18 +408,21 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param outputMode The output mode of the function. * @param stateEncoder Encoder for the state type. * @param outputEncoder Encoder for the output type. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
- * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], - outputMode: String, + outputMode: OutputMode, stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder) + outputEncoder: Encoder[U], + timeoutConf: KeyedStateTimeout): Dataset[U] = { + val f = (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala + flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala deleted file mode 100644 index 71efa4384211..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState - -/** - * :: Experimental :: - * - * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on - * [[KeyValueGroupedDataset]]. - * - * Detail description on `[map/flatMap]GroupsWithState` operation - * ------------------------------------------------------------ - * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] - * will invoke the user-given function on each group (defined by the grouping function in - * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger. - * That is, in every batch of the `streaming.StreamingQuery`, - * the function will be invoked once for each group that has data in the batch. - * - * The function is invoked with following parameters. - * - The key of the group. - * - An iterator containing all the values for this key. - * - A user-defined state object set by previous invocations of the given function. - * In case of a batch Dataset, there is only one invocation and state object will be empty as - * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` - * is equivalent to `[map/flatMap]Groups`. - * - * Important points to note about the function. - * - In a trigger, the function will be called only the groups present in the batch. 
So do not - * assume that the function will be called in every trigger for every group that has state. - * - There is no guaranteed ordering of values in the iterator in the function, neither with - * batch, nor with streaming Datasets. - * - All the data will be shuffled before applying the function. - * - * Important points to note about using KeyedState. - * - The value of the state cannot be null. So updating state with null will throw - * `IllegalArgumentException`. - * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. - * - If `remove()` is called, then `exists()` will return `false`, - * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` - * - After that, if `update(newState)` is called, then `exists()` will again return `true`, - * `get()` and `getOption()`will return the updated value. - * - * Scala example of using KeyedState in `mapGroupsWithState`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { - * // Check if state exists - * if (state.exists) { - * val existingState = state.get // Get the existing state - * val shouldRemove = ... // Decide whether to remove the state - * if (shouldRemove) { - * state.remove() // Remove the state - * } else { - * val newState = ... - * state.update(newState) // Set the new state - * } - * } else { - * val initialState = ... - * state.update(initialState) // Set the initial state - * } - * ... // return something - * } - * - * }}} - * - * Java example of using `KeyedState`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * MapGroupsWithStateFunction mappingFunction = - * new MapGroupsWithStateFunction() { - * - * @Override - * public String call(String key, Iterator value, KeyedState state) { - * if (state.exists()) { - * int existingState = state.get(); // Get the existing state - * boolean shouldRemove = ...; // Decide whether to remove the state - * if (shouldRemove) { - * state.remove(); // Remove the state - * } else { - * int newState = ...; - * state.update(newState); // Set the new state - * } - * } else { - * int initialState = ...; // Set the initial state - * state.update(initialState); - * } - * ... // return something - * } - * }; - * }}} - * - * @tparam S User-defined type of the state to be stored for each key. Must be encodable into - * Spark SQL types (see [[Encoder]] for more details). - * @since 2.1.1 - */ -@Experimental -@InterfaceStability.Evolving -trait KeyedState[S] extends LogicalKeyedState[S] { - - /** Whether state exists or not. */ - def exists: Boolean - - /** Get the state value if it exists, or throw NoSuchElementException. */ - @throws[NoSuchElementException]("when state does not exist") - def get: S - - /** Get the state value as a scala Option. */ - def getOption: Option[S] - - /** - * Update the value of the state. Note that `null` is not a valid value, and it throws - * IllegalArgumentException. - */ - @throws[IllegalArgumentException]("when updating with null") - def update(newState: S): Unit - - /** Remove this keyed state. 
*/ - def remove(): Unit -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0f7aa3709c1c..9e58e8ce3d5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -329,22 +329,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. */ - object MapGroupsWithStateStrategy extends Strategy { + object FlatMapGroupsWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FlatMapGroupsWithState( - f, - keyDeser, - valueDeser, - groupAttr, - dataAttr, - outputAttr, - stateDeser, - stateSer, - outputMode, - child, - _) => + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, + timeout, child) => val execPlan = FlatMapGroupsWithStateExec( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, + timeout, batchTimestampMs = KeyedStateImpl.NO_BATCH_PROCESSING_TIMESTAMP, planLater(child)) execPlan :: Nil case _ => @@ -392,7 +384,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil case logical.FlatMapGroupsWithState( - f, key, value, grouping, data, output, _, _, _, child, _) => + f, key, value, grouping, data, output, _, _, _, _, child) => execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 5de45b159684..41d91d877d4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.debug._ -import org.apache.spark.sql.execution.streaming.IncrementalExecution +import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ @@ -106,7 +106,8 @@ case class ExplainCommand( if (logicalPlan.isStreaming) { // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. 
- new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0, 0) + new IncrementalExecution( + sparkSession, logicalPlan, OutputMode.Append(), "", 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala new file mode 100644 index 000000000000..991d8ef70756 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalKeyedState, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.types.{BooleanType, IntegerType} +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing `FlatMapGroupsWithState.` + * + * @param func function called on each group + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. + * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr used to define the output object + * @param stateEncoder used to serialize/deserialize state before calling `func` + * @param outputMode the output mode of `func` + * @param timeout used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. 
+ */ +case class FlatMapGroupsWithStateExec( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateId: Option[OperatorStateId], + stateEncoder: ExpressionEncoder[Any], + outputMode: OutputMode, + timeout: KeyedStateTimeout, + batchTimestampMs: Long, + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { + + private val isTimeoutEnabled = timeout == ProcessingTimeTimeout + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + + import KeyedStateImpl._ + + /** Distribute by grouping attributes */ + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + /** Ordering needed for using GroupingIterator */ + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateId.checkpointLocation, + getStateId.operatorId, + getStateId.batchId, + groupingAttributes.toStructType, + stateAttributes.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iterator) => + val updater = new StateStoreUpdater(store) + + // Generate a iterator that returns the rows grouped by the grouping function + // Note that this code ensures that the filtering for timeout occurs only after + // all the data has been processed. This is to ensure that the timeout information of all + // the keys with data is updated before they are processed for timeouts. 
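// (Iterator.++ evaluates its right-hand operand lazily, so updateStateForTimedOutKeys() is not invoked until the iterator over keys with data has been fully consumed.)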
+ val outputIterator = + updater.updateStateForKeysWithData(iterator) ++ updater.updateStateForTimedOutKeys() + + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumed, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIterator, + { + store.commit() + longMetric("numTotalStateRows") += store.numKeys() + } + ) + } + } + + /** Helper class to update the state store */ + class StateStoreUpdater(store: StateStore) { + + // Converters for translating input keys, values, output data between rows and Java objects + private val getKeyObj = + ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + private val getValueObj = + ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + + // Converter for translating state rows to Java objects + private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateEncoder.resolveAndBind().deserializer, stateAttributes) + + // Converter for translating state Java objects to rows + private val stateSerializer = { + val encoderSerializer = stateEncoder.namedExpressions + if (isTimeoutEnabled) { + encoderSerializer :+ Literal(KeyedStateImpl.TIMEOUT_TIMESTAMP_NOT_SET) + } else { + encoderSerializer + } + } + private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Index of the additional metadata fields in the state row + private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) + + // Metrics + private val numUpdatedStateRows = longMetric("numUpdatedStateRows") + private val numOutputRows = longMetric("numOutputRows") + + /** + * For every group, get the key, values and corresponding state and call the function, + * and return an iterator of rows + */ + def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + groupedIter.flatMap { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + callFunctionAndUpdateState( + keyUnsafeRow, + valueRowIter, + store.get(keyUnsafeRow), + hasTimedOut = false) + } + } + + /** Find the groups that have timeout set and are timing out right now, and call the function */ + def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timingOutKeys = store.filter { case (_, stateRow) => + val timeoutTimestamp = getTimeoutTimestamp(stateRow) + timeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET && timeoutTimestamp < batchTimestampMs + } + timingOutKeys.flatMap { case (keyRow, stateRow) => + callFunctionAndUpdateState( + keyRow, + Iterator.empty, + Some(stateRow), + hasTimedOut = true) + } + } else Iterator.empty + } + + /** + * Call the user function on a key's data, update the state store, and return the return data + * iterator. Note that the store updating is lazy, that is, the store will be updated only + * after the returned iterator is fully consumed. 
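* (The laziness comes from wrapping the returned rows in a CompletionIterator, whose completion action performs the actual store puts/removes.)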
+ */ + private def callFunctionAndUpdateState( + keyRow: UnsafeRow, + valueRowIter: Iterator[InternalRow], + prevStateRowOption: Option[UnsafeRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + + val keyObj = getKeyObj(keyRow) // convert key to objects + val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects + val stateObjOption = getStateObj(prevStateRowOption) + val keyedState = new KeyedStateImpl( + stateObjOption, batchTimestampMs, isTimeoutEnabled, hasTimedOut) + + // Call function, get the returned objects and convert them to rows + val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + numOutputRows += 1 + getOutputRow(obj) + } + + // When the iterator is consumed, then write changes to state + def onIteratorCompletion: Unit = { + // Has the timeout information changed + + if (keyedState.hasRemoved) { + store.remove(keyRow) + numUpdatedStateRows += 1 + + } else { + val previousTimeoutTimestamp = prevStateRowOption match { + case Some(row) => getTimeoutTimestamp(row) + case None => TIMEOUT_TIMESTAMP_NOT_SET + } + + val stateRowToWrite = if (keyedState.hasUpdated) { + getStateRow(keyedState.get) + } else { + prevStateRowOption.orNull + } + + val hasTimeoutChanged = keyedState.getTimeoutTimestamp != previousTimeoutTimestamp + val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + + if (shouldWriteState) { + if (stateRowToWrite == null) { + // This should never happen because checks in KeyedStateImpl should avoid cases + // where empty state would need to be written + throw new IllegalStateException( + "Attempting to write empty state") + } + setTimeoutTimestamp(stateRowToWrite, keyedState.getTimeoutTimestamp) + store.put(keyRow.copy(), stateRowToWrite.copy()) + numUpdatedStateRows += 1 + } + } + } + + // Return an iterator of rows such that fully consumed, the updated state value will be saved + CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) + } + + /** Returns the state as Java object if defined */ + def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = { + stateRowOption.map(getStateObjFromRow) + } + + /** Returns the row for an updated state */ + def getStateRow(obj: Any): UnsafeRow = { + getStateRowFromObj(obj) + } + + /** Returns the timeout timestamp of a state row is set */ + def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { + if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else TIMEOUT_TIMESTAMP_NOT_SET + } + + /** Set the timestamp in a state row */ + def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 610ce5e1ebf5..a934c75a0245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -37,13 +37,13 @@ class IncrementalExecution( val outputMode: OutputMode, val checkpointLocation: String, val currentBatchId: Long, - val currentEventTimeWatermark: Long) + offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { // TODO: make this always part of planning. 
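// Streaming-specific planning strategies (plus any user-supplied experimental strategies) that are applied when planning each incremental batch.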
val streamingExtraStrategies = sparkSession.sessionState.planner.StatefulAggregationStrategy +: - sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: + sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -88,12 +88,13 @@ class IncrementalExecution( keys, Some(stateId), Some(outputMode), - Some(currentEventTimeWatermark), + Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( StateStoreRestoreExec( keys, Some(stateId), child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) @@ -102,13 +103,12 @@ class IncrementalExecution( keys, child, Some(stateId), - Some(currentEventTimeWatermark)) - case FlatMapGroupsWithStateExec( - f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => + Some(offsetSeqMetadata.batchWatermarkMs)) + + case m: FlatMapGroupsWithStateExec => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - FlatMapGroupsWithStateExec( - f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) + m.copy(stateId = Some(stateId), batchTimestampMs = offsetSeqMetadata.batchTimestampMs) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index eee7ec45dd77..ac421d395beb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -17,15 +17,37 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.KeyedState +import org.apache.commons.lang3.StringUtils -/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */ -private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { +import org.apache.spark.sql.streaming.KeyedState +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. + * @param optionalValue Optional value of the state + * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp + * for processing time timeouts + * @param isTimeoutEnabled Whether timeout is enabled. This will be used to check whether the user + * is allowed to configure timeouts. + * @param hasTimedOut Whether the key for which this state wrapped is being created is + * getting timed out or not. 
+ */ +private[sql] class KeyedStateImpl[S]( + optionalValue: Option[S], + batchProcessingTimeMs: Long, + isTimeoutEnabled: Boolean, + override val hasTimedOut: Boolean) extends KeyedState[S] { + + import KeyedStateImpl._ + + // Constructor to create dummy state when using mapGroupsWithState in a batch query + def this(optionalValue: Option[S]) = this( + optionalValue, -1, isTimeoutEnabled = false, hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined - private var updated: Boolean = false - // whether value has been updated (but not removed) + private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed + private var timeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET // ========= Public API ========= override def exists: Boolean = defined @@ -60,6 +82,55 @@ private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedStat defined = false updated = false removed = true + timeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET + } + + override def setTimeoutDuration(durationMs: Long): Unit = { + if (!isTimeoutEnabled) { + throw new UnsupportedOperationException( + "Cannot set timeout information without enabling timeout in map/flatMapGroupsWithState") + } + if (!defined) { + throw new IllegalStateException( + "Cannot set timeout information without any state value, " + + "state has either not been initialized, or has already been removed") + } + + if (durationMs <= 0) { + throw new IllegalArgumentException("Timeout duration must be positive") + } + if (!removed && batchProcessingTimeMs != NO_BATCH_PROCESSING_TIMESTAMP) { + timeoutTimestamp = durationMs + batchProcessingTimeMs + } else { + // This is being called in a batch query, hence no processing timestamp. + // Just ignore any attempts to set timeout. 
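// For example: in a batch query this call is a no-op, whereas in a streaming query with batchProcessingTimeMs = 1000, setTimeoutDuration(10000) above sets the timeout timestamp to 11000, and the key is treated as timed out once the batch processing timestamp exceeds that value.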
+ } + } + + override def setTimeoutDuration(duration: String): Unit = { + if (StringUtils.isBlank(duration)) { + throw new IllegalArgumentException( + "The window duration, slide duration and start time cannot be null or blank.") + } + val intervalString = if (duration.startsWith("interval")) { + duration + } else { + "interval " + duration + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"The provided duration ($duration) is not valid.") + } + if (cal.milliseconds < 0 || cal.months < 0) { + throw new IllegalArgumentException("Timeout duration must be positive") + } + + val delayMs = { + val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + cal.milliseconds + cal.months * millisPerMonth + } + setTimeoutDuration(delayMs) } override def toString: String = { @@ -69,12 +140,21 @@ private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedStat // ========= Internal API ========= /** Whether the state has been marked for removing */ - def isRemoved: Boolean = { - removed - } + def hasRemoved: Boolean = removed - /** Whether the state has been been updated */ - def isUpdated: Boolean = { - updated - } + /** Whether the state has been updated */ + def hasUpdated: Boolean = updated + + /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */ + def getTimeoutTimestamp: Long = timeoutTimestamp +} + + +private[sql] object KeyedStateImpl { + // Value used in the state row to represent the lack of any timeout timestamp + val TIMEOUT_TIMESTAMP_NOT_SET = -1L + + // Value to represent that no batch processing timestamp is passed to KeyedStateImpl. This is + // used in batch queries where there are no streaming batches and timeouts. + val NO_BATCH_PROCESSING_TIMESTAMP = -1L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 40faddccc242..60d5283e6b21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -590,7 +590,7 @@ class StreamExecution( outputMode, checkpointFile("state"), currentBatchId, - offsetSeqMetadata.batchWatermarkMs) + offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ab1204a750fa..f9dd80230e48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -73,7 +73,12 @@ private[state] class HDFSBackedStateStoreProvider( hadoopConf: Configuration ) extends StateStoreProvider with Logging { - type MapType = java.util.HashMap[UnsafeRow, UnsafeRow] + // ConcurrentHashMap is used because it generates fail-safe iterators on filtering + // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in + // the map when the iterator was created + // - Any updates to the map while iterating through the filtered iterator does not throw + // java.util.ConcurrentModificationException + type MapType = 
java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow] /** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */ class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType) @@ -99,6 +104,16 @@ private[state] class HDFSBackedStateStoreProvider( Option(mapToUpdate.get(key)) } + override def filter( + condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { + mapToUpdate + .entrySet + .asScala + .iterator + .filter { entry => condition(entry.getKey, entry.getValue) } + .map { entry => (entry.getKey, entry.getValue) } + } + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot put after already committed or aborted") @@ -227,7 +242,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def toString(): String = { - s"HDFSStateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dcb24b26f78f..eaa558eb6d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -50,6 +50,15 @@ trait StateStore { /** Get the current value of a key. */ def get(key: UnsafeRow): Option[UnsafeRow] + /** + * Return an iterator of key-value pairs that satisfy a certain condition. + * Note that the iterator must be fail-safe towards modification to the store, that is, + * it must be based on a snapshot of the store at the time of this call, and any change made to the + * store while iterating through the iterator should not cause the iterator to fail or have + * any effect on the values in the iterator. + */ + def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] + /** Put a new value for a key.
*/ def put(key: UnsafeRow, value: UnsafeRow): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c3075a3eacaa..6d2de441eb44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -19,17 +19,18 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{DataType, NullType, StructType} +import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.types._ import org.apache.spark.util.CompletionIterator @@ -256,94 +257,6 @@ case class StateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning } - -/** Physical operator for executing streaming flatMapGroupsWithState. 
*/ -case class FlatMapGroupsWithStateExec( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], - keyDeserializer: Expression, - valueDeserializer: Expression, - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - outputObjAttr: Attribute, - stateId: Option[OperatorStateId], - stateDeserializer: Expression, - stateSerializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { - - override def outputPartitioning: Partitioning = child.outputPartitioning - - /** Distribute by grouping attributes */ - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil - - /** Ordering needed for using GroupingIterator */ - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - getStateId.operatorId, - getStateId.batchId, - groupingAttributes.toStructType, - child.output.toStructType, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val numOutputRows = longMetric("numOutputRows") - - // Generate a iterator that returns the rows grouped by the grouping function - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - - // Converters to and from object and rows - val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - val getStateObj = - ObjectOperator.deserializeRowToObject(stateDeserializer) - val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // For every group, get the key, values and corresponding state and call the function, - // and return an iterator of rows - val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => - - val key = keyRow.asInstanceOf[UnsafeRow] - val keyObj = getKeyObj(keyRow) // convert key to objects - val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObjOption = store.get(key).map(getStateObj) // get existing state if any - val wrappedState = new KeyedStateImpl(stateObjOption) - val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj => - numOutputRows += 1 - getOutputRow(obj) // convert back to rows - } - - // Return an iterator of rows generated this key, - // such that fully consumed, the updated state value will be saved - CompletionIterator[InternalRow, Iterator[InternalRow]]( - mappedIterator, { - // When the iterator is consumed, then write changes to state - if (wrappedState.isRemoved) { - store.remove(key) - numUpdatedStateRows += 1 - } else if (wrappedState.isUpdated) { - store.put(key, outputStateObj(wrappedState.get)) - numUpdatedStateRows += 1 - } - }) - } - - // Return an iterator of all the rows generated by all the keys, such that when fully - // consumer, all the state updates will be committed by the state store - CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, { - store.commit() - numTotalStateRows += store.numKeys() - }) - } - } -} - - /** Physical operator for executing streaming 
Deduplicate. */ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala new file mode 100644 index 000000000000..6b4b1ced98a3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.{Encoder, KeyValueGroupedDataset} +import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState + +/** + * :: Experimental :: + * + * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on + * [[KeyValueGroupedDataset]]. + * + * Detail description on `[map/flatMap]GroupsWithState` operation + * -------------------------------------------------------------- + * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] + * will invoke the user-given function on each group (defined by the grouping function in + * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger. + * That is, in every batch of the `streaming.StreamingQuery`, + * the function will be invoked once for each group that has data in the trigger. Furthermore, + * if timeout is set, then the function will invoked on timed out keys (more detail below). + * + * The function is invoked with following parameters. + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * In case of a batch Dataset, there is only one invocation and state object will be empty as + * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` + * is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have + * no effect. + * + * Important points to note about the function. + * - In a trigger, the function will be called only the groups present in the batch. So do not + * assume that the function will be called in every trigger for every group that has state. + * - There is no guaranteed ordering of values in the iterator in the function, neither with + * batch, nor with streaming Datasets. + * - All the data will be shuffled before applying the function. + * - If timeout is set, then the function will also be called with no values. 
+ * See more details on `KeyedStateTimeout` below. + * + * Important points to note about using `KeyedState`. + * - The value of the state cannot be null. So updating state with null will throw + * `IllegalArgumentException`. + * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. + * - If `remove()` is called, then `exists()` will return `false`, + * `get()` will throw `NoSuchElementException` and `getOption()` will return `None`. + * - After that, if `update(newState)` is called, then `exists()` will again return `true`, + * `get()` and `getOption()` will return the updated value. + * + * Important points to note about using `KeyedStateTimeout`. + * - The timeout type is a global param across all the keys (set as `timeout` param in + * `[map|flatMap]GroupsWithState`), but the exact timeout duration is configurable per key + * (by calling `setTimeout...()` in `KeyedState`). + * - When the timeout occurs for a key, the function is called with no values, and + * `KeyedState.hasTimedOut()` set to true. + * - The timeout is reset for a key every time the function is called on the key, that is, + * when the key has new data, or the key has timed out. So the user has to set the timeout + * duration every time the function is called, otherwise there will not be any timeout set. + * - Guarantees provided on processing-time-based timeout of a key, when timeout duration is D ms: + * - Timeout will never be called before real clock time has advanced by D ms + * - Timeout will be called eventually when there is a trigger in the query + * (i.e. after D ms). So there is no strict upper bound on when the timeout would occur. + * For example, the trigger interval of the query will affect when the timeout is actually hit. + * If there is no data in the stream (for any key) for a while, then there will not be + * any trigger and timeout will not be hit until there is data. + * + * Scala example of using KeyedState in `mapGroupsWithState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. + * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { + * + * if (state.hasTimedOut) { // If called when timing out, remove the state + * state.remove() + * + * } else if (state.exists) { // If state exists, use it for processing + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * state.setTimeoutDuration("1 hour") // Set the timeout + * } + * + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * state.setTimeoutDuration("1 hour") // Set the timeout + * } + * ... + * // return something + * } + * + * dataset + * .groupByKey(...) + * .mapGroupsWithState(KeyedStateTimeout.ProcessingTimeTimeout)(mappingFunction) + * }}} + * + * Java example of using `KeyedState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
+ * MapGroupsWithStateFunction mappingFunction = + * new MapGroupsWithStateFunction() { + * + * @Override + * public String call(String key, Iterator value, KeyedState state) { + * if (state.hasTimedOut()) { // If called when timing out, remove the state + * state.remove(); + * + * } else if (state.exists()) { // If state exists, use it for processing + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * state.setTimeoutDuration("1 hour"); // Set the timeout + * } + * + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * state.setTimeoutDuration("1 hour"); // Set the timeout + * } + * ... +* // return something + * } + * }; + * + * dataset + * .groupByKey(...) + * .mapGroupsWithState( + * mappingFunction, Encoders.INT, Encoders.STRING, KeyedStateTimeout.ProcessingTimeTimeout); + * }}} + * + * @tparam S User-defined type of the state to be stored for each key. Must be encodable into + * Spark SQL types (see [[Encoder]] for more details). + * @since 2.2.0 + */ +@Experimental +@InterfaceStability.Evolving +trait KeyedState[S] extends LogicalKeyedState[S] { + + /** Whether state exists or not. */ + def exists: Boolean + + /** Get the state value if it exists, or throw NoSuchElementException. */ + @throws[NoSuchElementException]("when state does not exist") + def get: S + + /** Get the state value as a scala Option. */ + def getOption: Option[S] + + /** + * Update the value of the state. Note that `null` is not a valid value, and it throws + * IllegalArgumentException. + */ + @throws[IllegalArgumentException]("when updating with null") + def update(newState: S): Unit + + /** Remove this keyed state. Note that this resets any timeout configuration as well. */ + def remove(): Unit + + /** + * Whether the function has been called because the key has timed out. + * @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`. + */ + def hasTimedOut: Boolean + + /** + * Set the timeout duration in ms for this key. + * @note Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + */ + @throws[IllegalArgumentException]("if 'durationMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + def setTimeoutDuration(durationMs: Long): Unit + + /** + * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. + * @note, Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. 
+ */ + @throws[IllegalArgumentException]("if 'duration' is not a valid duration") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + def setTimeoutDuration(duration: String): Unit +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 439cac3dfbcb..ca9e5ad2ea86 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,7 @@ import java.sql.Timestamp; import java.util.*; +import org.apache.spark.sql.streaming.KeyedStateTimeout; import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; import scala.Tuple3; @@ -208,7 +209,8 @@ public void testGroupBy() { }, OutputMode.Append(), Encoders.LONG(), - Encoders.STRING()); + Encoders.STRING(), + KeyedStateTimeout.NoTimeout()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index e848f74e3159..ebb7422765eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -123,6 +123,30 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) } + test("filter and concurrent updates") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(provider.latestIterator.isEmpty) + val store = provider.getStore(0) + put(store, "a", 1) + put(store, "b", 2) + + // Updates should work while iterating of filtered entries + val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" } + filtered.foreach { case (keyRow, valueRow) => + store.put(keyRow, intToRow(rowToInt(valueRow) + 1)) + } + assert(get(store, "a") === Some(2)) + + // Removes should work while iterating of filtered entries + val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" } + filtered2.foreach { case (keyRow, _) => + store.remove(keyRow) + } + assert(get(store, "b") === None) + } + test("updates iterator with all combos of updates and removes") { val provider = newStoreProvider() var currentVersion: Int = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 902b842e97aa..7daa5e6a0f61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,20 +17,33 @@ package org.apache.spark.sql.streaming +import java.util +import java.util.concurrent.ConcurrentHashMap + import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.KeyedState +import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction +import org.apache.spark.sql.Encoder +import 
org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState +import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.RDDScanExec +import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, KeyedStateImpl, MemoryStream} +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore +import org.apache.spark.sql.types.{DataType, IntegerType} /** Class to check custom state types */ case class RunningCount(count: Long) +case class Result(key: Long, count: Int) + class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ + import KeyedStateImpl._ override def afterAll(): Unit = { super.afterAll() @@ -54,8 +67,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } assert(state.getOption === expectedData) - assert(state.isUpdated === shouldBeUpdated) - assert(state.isRemoved === shouldBeRemoved) + assert(state.hasUpdated === shouldBeUpdated) + assert(state.hasRemoved === shouldBeRemoved) } // Updating empty state @@ -83,6 +96,79 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } + test("KeyedState - setTimeoutDuration, hasTimedOut") { + import KeyedStateImpl._ + var state: KeyedStateImpl[Int] = null + + // When isTimeoutEnabled = false, then setTimeoutDuration() is not allowed + for (initState <- Seq(None, Some(5))) { + // for different initial state + state = new KeyedStateImpl(initState, 1000, isTimeoutEnabled = false, hasTimedOut = false) + assert(state.hasTimedOut === false) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + intercept[UnsupportedOperationException] { + state.setTimeoutDuration(1000) + } + intercept[UnsupportedOperationException] { + state.setTimeoutDuration("1 day") + } + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + } + + def testTimeoutNotAllowed(): Unit = { + intercept[IllegalStateException] { + state.setTimeoutDuration(1000) + } + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + intercept[IllegalStateException] { + state.setTimeoutDuration("2 second") + } + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + } + + // When isTimeoutEnabled = true, then setTimeoutDuration() is not allowed until the + // state is be defined + state = new KeyedStateImpl(None, 1000, isTimeoutEnabled = true, hasTimedOut = false) + assert(state.hasTimedOut === false) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + testTimeoutNotAllowed() + + // After state has been set, setTimeoutDuration() is allowed, and + // getTimeoutTimestamp returned correct timestamp + state.update(5) + assert(state.hasTimedOut === false) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + state.setTimeoutDuration(1000) + assert(state.getTimeoutTimestamp === 2000) + state.setTimeoutDuration("2 second") + assert(state.getTimeoutTimestamp === 3000) + assert(state.hasTimedOut === false) + + // setTimeoutDuration() with negative values or 0 is not allowed + def testIllegalTimeout(body: => Unit): Unit 
= { + intercept[IllegalArgumentException] { body } + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + } + state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) + testIllegalTimeout { state.setTimeoutDuration(-1000) } + testIllegalTimeout { state.setTimeoutDuration(0) } + testIllegalTimeout { state.setTimeoutDuration("-2 second") } + testIllegalTimeout { state.setTimeoutDuration("-1 month") } + testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } + + // Test remove() clear timeout timestamp, and setTimeoutDuration() is not allowed after that + state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) + state.remove() + assert(state.hasTimedOut === false) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + testTimeoutNotAllowed() + + // Test hasTimedOut = true + state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = true) + assert(state.hasTimedOut === true) + assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + } + test("KeyedState - primitive type") { var intState = new KeyedStateImpl[Int](None) intercept[NoSuchElementException] { @@ -100,6 +186,151 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } + // Values used for testing StateStoreUpdater + val currentTimestamp = 1000 + val beforeCurrentTimestamp = 999 + val afterCurrentTimestamp = 1001 + + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is disabled + for (priorState <- Seq(None, Some(0))) { + val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" + val testName = s"timeout disabled - $priorStateStr - " + + testStateUpdateWithData( + testName + "no update", + stateUpdates = state => { /* do nothing */ }, + timeoutType = KeyedStateTimeout.NoTimeout, + priorState = priorState, + expectedState = priorState) // should not change + + testStateUpdateWithData( + testName + "state updated", + stateUpdates = state => { state.update(5) }, + timeoutType = KeyedStateTimeout.NoTimeout, + priorState = priorState, + expectedState = Some(5)) // should change + + testStateUpdateWithData( + testName + "state removed", + stateUpdates = state => { state.remove() }, + timeoutType = KeyedStateTimeout.NoTimeout, + priorState = priorState, + expectedState = None) // should be removed + } + + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is enabled + for (priorState <- Seq(None, Some(0))) { + for (priorTimeoutTimestamp <- Seq(TIMEOUT_TIMESTAMP_NOT_SET, 1000)) { + var testName = s"timeout enabled - " + if (priorState.nonEmpty) { + testName += "prior state set, " + if (priorTimeoutTimestamp == 1000) { + testName += "prior timeout set - " + } else { + testName += "no prior timeout - " + } + } else { + testName += "no prior state - " + } + + testStateUpdateWithData( + testName + "no update", + stateUpdates = state => { /* do nothing */ }, + timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = priorState, // state should not change + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + + testStateUpdateWithData( + testName + "state updated", + stateUpdates = state => { state.update(5) }, + timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + 
expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + + testStateUpdateWithData( + testName + "state removed", + stateUpdates = state => { state.remove() }, + timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None) // state should be removed + + testStateUpdateWithData( + testName + "timeout and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutDuration(5000) }, + timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentTimestamp + 5000) // timestamp should change + } + } + + // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + val preTimeoutState = Some(5) + + testStateUpdateWithTimeout( + "should not timeout", + stateUpdates = state => { assert(false, "function called without timeout") }, + priorTimeoutTimestamp = afterCurrentTimestamp, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = afterCurrentTimestamp) // timestamp should not change + + testStateUpdateWithTimeout( + "should timeout - no update/remove", + stateUpdates = state => { /* do nothing */ }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + + testStateUpdateWithTimeout( + "should timeout - update state", + stateUpdates = state => { state.update(5) }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + + testStateUpdateWithTimeout( + "should timeout - remove state", + stateUpdates = state => { state.remove() }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = None, // state should be removed + expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) + + testStateUpdateWithTimeout( + "should timeout - timeout updated", + stateUpdates = state => { state.setTimeoutDuration(2000) }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + + testStateUpdateWithTimeout( + "should timeout - timeout and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, + priorTimeoutTimestamp = beforeCurrentTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + + test("StateStoreUpdater - rows are cloned before writing to StateStore") { + // function for running count + val func = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + state.update(state.getOption.getOrElse(0) + values.size) + Iterator.empty + } + val store = newStateStore() + val plan = newFlatMapGroupsWithStateExec(func) + val updater = new plan.StateStoreUpdater(store) + val data = Seq(1, 1, 2) + val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow)) + returnIter.size // consume the iterator to force store updates + val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet + assert(storeData === Set((1, 2), (2, 1))) + } + test("flatMapGroupsWithState - streaming") { // Function to maintain 
running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything @@ -119,7 +350,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) + .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a"), @@ -162,8 +393,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str) - + .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), @@ -178,59 +408,118 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } + test("flatMapGroupsWithState - streaming + aggregation") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + Iterator(key -> "-1") + } else { + state.update(RunningCount(count)) + Iterator(key -> count.toString) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Append, KeyedStateTimeout.NoTimeout)(stateFunc) + .groupByKey(_._1) + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckLastBatch(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckLastBatch(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckLastBatch(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + ) + } + test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") Iterator((key, values.size)) } - checkAnswer( - Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF, - Seq(("a", 2), ("b", 1)).toDF) + val df = Seq("a", "a", "b").toDS + .groupByKey(x => x) + .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc).toDF + checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } - test("mapGroupsWithState - streaming") { + test("flatMapGroupsWithState - streaming with processing time timeout") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - - val count = 
state.getOption.map(_.count).getOrElse(0L) + values.size - if (count == 3) { + if (state.hasTimedOut) { state.remove() - (key, "-1") + Iterator((key, "-1")) } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) - (key, count.toString) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) } } + val clock = new StreamManualClock val inputData = MemoryStream[String] + val timeout = KeyedStateTimeout.ProcessingTimeTimeout val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + .flatMapGroupsWithState(Update, timeout)(stateFunc) testStream(result, Update)( + StartStream(ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), - assertNumStateRows(total = 2, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckLastBatch(("b", "1")), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(10 * 1000), CheckLastBatch(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), + StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), - assertNumStateRows(total = 3, updated = 2) + StartStream(ProcessingTime("1 second"), triggerClock = clock), + + AddData(inputData, "c"), + AdvanceManualClock(20 * 1000), + CheckLastBatch(("b", "-1"), ("c", "1")), + assertNumStateRows(total = 1, updated = 2), + + AddData(inputData, "c"), + AdvanceManualClock(20 * 1000), + CheckLastBatch(("c", "2")), + assertNumStateRows(total = 1, updated = 1) ) } - test("flatMapGroupsWithState - streaming + aggregation") { + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { @@ -238,10 +527,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() - Iterator(key -> "-1") + (key, "-1") } else { state.update(RunningCount(count)) - Iterator(key -> count.toString) + (key, count.toString) } } @@ -249,28 +538,25 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str) - .groupByKey(_._1) - .count() + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - testStream(result, Complete)( + testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", 1)), + CheckLastBatch(("a", "1")), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), + CheckLastBatch(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), - AddData(inputData, "a", "b"), - // mapGroups 
should remove state for "a" and generate ("a", "-1"), ("b", "2") ; - // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), + AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + CheckLastBatch(("a", "-1"), ("b", "2")), + assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), - AddData(inputData, "a", "c"), - // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; - // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 + CheckLastBatch(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) ) } @@ -322,23 +608,185 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } + test("output partitioning is unknown") { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key + val inputData = MemoryStream[String] + val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) + result + testStream(result, Update)( + AddData(inputData, "a"), + CheckLastBatch("a"), + AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) + ) + } + test("disallow complete mode") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[Int]) => { Iterator[String]() } var e = intercept[IllegalArgumentException] { - MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete) + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( + OutputMode.Complete, KeyedStateTimeout.NoTimeout)(stateFunc) } assert(e.getMessage === "The output mode of function should be append or update") + val javaStateFunc = new FlatMapGroupsWithStateFunction[String, String, Int, String] { + import java.util.{Iterator => JIterator} + override def call( + key: String, + values: JIterator[String], + state: KeyedState[Int]): JIterator[String] = { null } + } e = intercept[IllegalArgumentException] { - MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete") + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( + javaStateFunc, OutputMode.Complete, + implicitly[Encoder[Int]], implicitly[Encoder[String]], KeyedStateTimeout.NoTimeout) } assert(e.getMessage === "The output mode of function should be append or update") } + + def testStateUpdateWithData( + testName: String, + stateUpdates: KeyedState[Int] => Unit, + timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + priorState: Option[Int], + priorTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET, + expectedState: Option[Int] = None, + expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + + if (priorState.isEmpty && priorTimeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET) { + return // there can be no prior timestamp, when there is no prior state + } + test(s"StateStoreUpdater - updates with data - $testName") { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + assert(state.hasTimedOut === false, "hasTimedOut not false") + assert(values.nonEmpty, "Some value is expected") + stateUpdates(state) + Iterator.empty + } + testStateUpdate( + testTimeoutUpdates = false, mapGroupsFunc, timeoutType, + priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + } + } + + def 
testStateUpdateWithTimeout( + testName: String, + stateUpdates: KeyedState[Int] => Unit, + priorTimeoutTimestamp: Long, + expectedState: Option[Int], + expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + + test(s"StateStoreUpdater - updates for timeout - $testName") { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + assert(state.hasTimedOut === true, "hasTimedOut not true") + assert(values.isEmpty, "values not empty") + stateUpdates(state) + Iterator.empty + } + testStateUpdate( + testTimeoutUpdates = true, mapGroupsFunc, KeyedStateTimeout.ProcessingTimeTimeout, + preTimeoutState, priorTimeoutTimestamp, + expectedState, expectedTimeoutTimestamp) + } + } + + def testStateUpdate( + testTimeoutUpdates: Boolean, + mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], + timeoutType: KeyedStateTimeout, + priorState: Option[Int], + priorTimeoutTimestamp: Long, + expectedState: Option[Int], + expectedTimeoutTimestamp: Long): Unit = { + + val store = newStateStore() + val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( + mapGroupsFunc, timeoutType, currentTimestamp) + val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val key = intToRow(0) + // Prepare store with prior state configs + if (priorState.nonEmpty) { + val row = updater.getStateRow(priorState.get) + updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) + store.put(key.copy(), row.copy()) + } + + // Call updating function to update state store + val returnedIter = if (testTimeoutUpdates) { + updater.updateStateForTimedOutKeys() + } else { + updater.updateStateForKeysWithData(Iterator(key)) + } + returnedIter.size // consumer the iterator to force state updates + + // Verify updated state in store + val updatedStateRow = store.get(key) + assert( + updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, + "final state not as expected") + if (updatedStateRow.nonEmpty) { + assert( + updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") + } + } + + def newFlatMapGroupsWithStateExec( + func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], + timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + batchTimestampMs: Long = NO_BATCH_PROCESSING_TIMESTAMP): FlatMapGroupsWithStateExec = { + MemoryStream[Int] + .toDS + .groupByKey(x => x) + .flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func) + .logicalPlan.collectFirst { + case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => + FlatMapGroupsWithStateExec( + f, k, v, g, d, o, None, s, m, t, currentTimestamp, + RDDScanExec(g, null, "rdd")) + }.get + } + + def newStateStore(): StateStore = new MemoryStateStore() + + val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + def intToRow(i: Int): UnsafeRow = { + intProj.apply(new GenericInternalRow(Array[Any](i))).copy() + } + + def rowToInt(row: UnsafeRow): Int = row.getInt(0) } object FlatMapGroupsWithStateSuite { + var failInTask = true + + class MemoryStateStore extends StateStore() { + import scala.collection.JavaConverters._ + private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] + + override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { + map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) } + } + + override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { + iterator.filter { case (k, v) => c(k, v) } + } + + override def get(key: 
UnsafeRow): Option[UnsafeRow] = Option(map.get(key)) + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue) + override def remove(key: UnsafeRow): Unit = { map.remove(key) } + override def remove(condition: (UnsafeRow) => Boolean): Unit = { + iterator.map(_._1).filter(condition).foreach(map.remove) + } + override def commit(): Long = version + 1 + override def abort(): Unit = { } + override def id: StateStoreId = null + override def version: Long = 0 + override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException } + override def numKeys(): Long = map.size + override def hasCommitted: Boolean = true + } } From 0cdcf9114527a2c359c25e46fd6556b3855bfb28 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 19 Mar 2017 22:33:01 -0700 Subject: [PATCH 071/512] [SPARK-19849][SQL] Support ArrayType in to_json to produce JSON array ## What changes were proposed in this pull request? This PR proposes to support an array of struct type in `to_json` as below: ```scala import org.apache.spark.sql.functions._ val df = Seq(Tuple1(Tuple1(1) :: Nil)).toDF("a") df.select(to_json($"a").as("json")).show() ``` ``` +----------+ | json| +----------+ |[{"_1":1}]| +----------+ ``` Currently, it throws an exception as below (a newline manually inserted for readability): ``` org.apache.spark.sql.AnalysisException: cannot resolve 'structtojson(`array`)' due to data type mismatch: structtojson requires that the expression is a struct expression.;; ``` This allows the roundtrip with `from_json` as below: ```scala import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val df = Seq("""[{"a":1}, {"a":2}]""").toDF("json").select(from_json($"json", schema).as("array")) df.show() // Read back. df.select(to_json($"array").as("json")).show() ``` ``` +----------+ | array| +----------+ |[[1], [2]]| +----------+ +-----------------+ | json| +-----------------+ |[{"a":1},{"a":2}]| +-----------------+ ``` Also, this PR proposes to rename from `StructToJson` to `StructsToJson ` and `JsonToStruct` to `JsonToStructs`. ## How was this patch tested? Unit tests in `JsonFunctionsSuite` and `JsonExpressionsSuite` for Scala, doctest for Python and test in `test_sparkSQL.R` for R. Author: hyukjinkwon Closes #17192 from HyukjinKwon/SPARK-19849. --- R/pkg/R/functions.R | 18 ++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 + python/pyspark/sql/functions.py | 15 ++- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/jsonExpressions.scala | 70 +++++++++----- .../sql/catalyst/json/JacksonGenerator.scala | 23 +++-- .../expressions/JsonExpressionsSuite.scala | 77 ++++++++++----- .../org/apache/spark/sql/functions.scala | 34 ++++--- .../sql-tests/inputs/json-functions.sql | 1 + .../sql-tests/results/json-functions.sql.out | 96 ++++++++++--------- .../apache/spark/sql/JsonFunctionsSuite.scala | 26 ++++- 11 files changed, 236 insertions(+), 132 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9867f2d5b7c5..2cff3ac08c3a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1795,10 +1795,10 @@ setMethod("to_date", #' to_json #' -#' Converts a column containing a \code{structType} into a Column of JSON string. -#' Resolving the Column can fail if an unsupported type is encountered. +#' Converts a column containing a \code{structType} or array of \code{structType} into a Column +#' of JSON string. 
Resolving the Column can fail if an unsupported type is encountered. #' -#' @param x Column containing the struct +#' @param x Column containing the struct or array of the structs #' @param ... additional named properties to control how it is converted, accepts the same options #' as the JSON data source. #' @@ -1809,8 +1809,13 @@ setMethod("to_date", #' @export #' @examples #' \dontrun{ -#' to_json(df$t, dateFormat = 'dd/MM/yyyy') -#' select(df, to_json(df$t)) +#' # Converts a struct into a JSON object +#' df <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df, to_json(df$d, dateFormat = 'dd/MM/yyyy')) +#' +#' # Converts an array of structs into a JSON array +#' df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' select(df, to_json(df$people)) #'} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), @@ -2433,7 +2438,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_json #' #' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema}. If the string is unparseable, the Column will contains the value NA. +#' \code{schema} or array of \code{structType} if \code{asJsonArray} is set to \code{TRUE}. +#' If the string is unparseable, the Column will contains the value NA. #' #' @param x Column containing the JSON string. #' @param schema a structType object to use as the schema to use when parsing the JSON string. diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 32856b399cdd..9c38e0d866aa 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1340,6 +1340,10 @@ test_that("column functions", { expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) # Test to_json(), from_json() + df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- read.json(mapTypeJsonPath) j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 376b86ea69bd..f9121e60f35b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1774,10 +1774,11 @@ def json_tuple(col, *fields): def from_json(col, schema, options={}): """ Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]] - with the specified schema. Returns `null`, in the case of an unparseable string. + of [[StructType]]s with the specified schema. Returns `null`, in the case of an unparseable + string. :param col: string column in json format - :param schema: a StructType or ArrayType to use when parsing the json column + :param schema: a StructType or ArrayType of StructType to use when parsing the json column :param options: options to control parsing. accepts the same options as the json datasource >>> from pyspark.sql.types import * @@ -1802,10 +1803,10 @@ def from_json(col, schema, options={}): @since(2.1) def to_json(col, options={}): """ - Converts a column containing a [[StructType]] into a JSON string. Throws an exception, - in the case of an unsupported type. + Converts a column containing a [[StructType]] or [[ArrayType]] of [[StructType]]s into a + JSON string. 
Throws an exception, in the case of an unsupported type. - :param col: name of column containing the struct + :param col: name of column containing the struct or array of the structs :param options: options to control converting. accepts the same options as the json datasource >>> from pyspark.sql import Row @@ -1814,6 +1815,10 @@ def to_json(col, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'{"age":2,"name":"Alice"}')] + >>> data = [(1, [Row(name='Alice', age=2), Row(name='Bob', age=3)])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')] """ sc = SparkContext._active_spark_context diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 0486e67dbdf8..e1d83a86f99d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -425,8 +425,8 @@ object FunctionRegistry { expression[BitwiseXor]("^"), // json - expression[StructToJson]("to_json"), - expression[JsonToStruct]("from_json"), + expression[StructsToJson]("to_json"), + expression[JsonToStructs]("from_json"), // Cast aliases (SPARK-16730) castAlias("boolean", BooleanType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 37e4bb506043..e4e08a8665a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -482,7 +482,8 @@ case class JsonTuple(children: Seq[Expression]) } /** - * Converts an json input string to a [[StructType]] or [[ArrayType]] with the specified schema. + * Converts an json input string to a [[StructType]] or [[ArrayType]] of [[StructType]]s + * with the specified schema. */ // scalastyle:off line.size.limit @ExpressionDescription( @@ -495,7 +496,7 @@ case class JsonTuple(children: Seq[Expression]) {"time":"2015-08-26 00:00:00.0"} """) // scalastyle:on line.size.limit -case class JsonToStruct( +case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, @@ -590,7 +591,7 @@ case class JsonToStruct( } /** - * Converts a [[StructType]] to a json output string. + * Converts a [[StructType]] or [[ArrayType]] of [[StructType]]s to a json output string. 
*/ // scalastyle:off line.size.limit @ExpressionDescription( @@ -601,9 +602,11 @@ case class JsonToStruct( {"a":1,"b":2} > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} + > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); + [{"a":1,"b":2}] """) // scalastyle:on line.size.limit -case class StructToJson( +case class StructsToJson( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) @@ -624,41 +627,58 @@ case class StructToJson( lazy val writer = new CharArrayWriter() @transient - lazy val gen = - new JacksonGenerator( - child.dataType.asInstanceOf[StructType], - writer, - new JSONOptions(options, timeZoneId.get)) + lazy val gen = new JacksonGenerator( + rowSchema, writer, new JSONOptions(options, timeZoneId.get)) + + @transient + lazy val rowSchema = child.dataType match { + case st: StructType => st + case ArrayType(st: StructType, _) => st + } + + // This converts rows to the JSON output according to the given schema. + @transient + lazy val converter: Any => UTF8String = { + def getAndReset(): UTF8String = { + gen.flush() + val json = writer.toString + writer.reset() + UTF8String.fromString(json) + } + + child.dataType match { + case _: StructType => + (row: Any) => + gen.write(row.asInstanceOf[InternalRow]) + getAndReset() + case ArrayType(_: StructType, _) => + (arr: Any) => + gen.write(arr.asInstanceOf[ArrayData]) + getAndReset() + } + } override def dataType: DataType = StringType - override def checkInputDataTypes(): TypeCheckResult = { - if (StructType.acceptsType(child.dataType)) { + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case _: StructType | ArrayType(_: StructType, _) => try { - JacksonUtils.verifySchema(child.dataType.asInstanceOf[StructType]) + JacksonUtils.verifySchema(rowSchema) TypeCheckResult.TypeCheckSuccess } catch { case e: UnsupportedOperationException => TypeCheckResult.TypeCheckFailure(e.getMessage) } - } else { - TypeCheckResult.TypeCheckFailure( - s"$prettyName requires that the expression is a struct expression.") - } + case _ => TypeCheckResult.TypeCheckFailure( + s"Input type ${child.dataType.simpleString} must be a struct or array of structs.") } override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(row: Any): Any = { - gen.write(row.asInstanceOf[InternalRow]) - gen.flush() - val json = writer.toString - writer.reset() - UTF8String.fromString(json) - } + override def nullSafeEval(value: Any): Any = converter(value) - override def inputTypes: Seq[AbstractDataType] = StructType :: Nil + override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil } object JsonExprUtils { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index dec55279c9fc..1d302aea6fd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -37,6 +37,10 @@ private[sql] class JacksonGenerator( // `ValueWriter`s for all fields of the schema private val rootFieldWriters: Array[ValueWriter] = schema.map(_.dataType).map(makeWriter).toArray + // `ValueWriter` for array data storing rows of the schema. 
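    // For example (illustrative only): with schema struct<a:int>, an ArrayData holding
    // the rows [1] and [2] is rendered by write(array: ArrayData) below as the JSON
    // text [{"a":1},{"a":2}], matching the to_json array tests added later in this patch.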
+ private val arrElementWriter: ValueWriter = (arr: SpecializedGetters, i: Int) => { + writeObject(writeFields(arr.getStruct(i, schema.length), schema, rootFieldWriters)) + } private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) @@ -185,17 +189,18 @@ private[sql] class JacksonGenerator( def flush(): Unit = gen.flush() /** - * Transforms a single InternalRow to JSON using Jackson + * Transforms a single `InternalRow` to JSON object using Jackson * * @param row The row to convert */ - def write(row: InternalRow): Unit = { - writeObject { - writeFields(row, schema, rootFieldWriters) - } - } + def write(row: InternalRow): Unit = writeObject(writeFields(row, schema, rootFieldWriters)) - def writeLineEnding(): Unit = { - gen.writeRaw('\n') - } + /** + * Transforms multiple `InternalRow`s to JSON array using Jackson + * + * @param array The array of rows to convert + */ + def write(array: ArrayData): Unit = writeArray(writeArrayData(array, arrElementWriter)) + + def writeLineEnding(): Unit = gen.writeRaw('\n') } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 19d0c8eb92f1..e4698d44636b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -21,7 +21,7 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -352,7 +352,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), InternalRow(1) ) } @@ -361,13 +361,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), null ) // Other modes should still return `null`. 
checkEvaluation( - JsonToStruct(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData), gmtId), null ) } @@ -376,62 +376,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal.create(null, StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), null ) } @@ -444,14 +444,14 @@ class JsonExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData1), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData1), Option("PST")), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), InternalRow(c.getTimeInMillis * 1000L) ) @@ -461,7 +461,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) checkEvaluation( - JsonToStruct( + JsonToStructs( schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), @@ -469,7 +469,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( - JsonToStruct( + JsonToStructs( schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), @@ -483,25 +483,52 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal.create(" ", StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), null ) } - test("to_json") { + test("to_json - struct") { val schema = StructType(StructField("a", IntegerType) :: Nil) val struct = Literal.create(create_row(1), schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), """{"a":1}""" ) } + test("to_json - array") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(InternalRow(1) :: InternalRow(2) :: Nil) + val output = """[{"a":1},{"a":2}]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + + test("to_json - array with single empty row") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(InternalRow(null) :: Nil) + val output = """[{}]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + + test("to_json - empty array") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(Nil) + val output = """[]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + test("to_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) val struct = Literal.create(null, schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), null ) } @@ -514,16 +541,16 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val struct = Literal.create(create_row(c.getTimeInMillis * 1000L), schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), """{"t":"2016-01-01T00:00:00.000Z"}""" ) checkEvaluation( - StructToJson(Map.empty, struct, Option("PST")), + StructsToJson(Map.empty, struct, Option("PST")), 
"""{"t":"2015-12-31T16:00:00.000-08:00"}""" ) checkEvaluation( - StructToJson( + StructsToJson( Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> gmtId.get), struct, @@ -531,7 +558,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { """{"t":"2016-01-01T00:00:00"}""" ) checkEvaluation( - StructToJson( + StructsToJson( Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> "PST"), struct, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 201f726db3fa..a9f089c850d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2978,7 +2978,8 @@ object functions { /** * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * with the specified schema. Returns `null`, in the case of an unparseable string. + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -2989,7 +2990,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStruct(schema, options, e.expr) + JsonToStructs(schema, options, e.expr) } /** @@ -3009,7 +3010,8 @@ object functions { /** * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * with the specified schema. Returns `null`, in the case of an unparseable string. + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3036,7 +3038,7 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3049,7 +3051,7 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3062,10 +3064,11 @@ object functions { from_json(e, DataType.fromJson(schema), options) /** - * (Scala-specific) Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s + * into a JSON string with the specified schema. Throws an exception, in the case of an + * unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. 
* @@ -3073,14 +3076,15 @@ object functions { * @since 2.1.0 */ def to_json(e: Column, options: Map[String, String]): Column = withExpr { - StructToJson(options, e.expr) + StructsToJson(options, e.expr) } /** - * (Java-specific) Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * (Java-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s + * into a JSON string with the specified schema. Throws an exception, in the case of an + * unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. * @@ -3091,10 +3095,10 @@ object functions { to_json(e, options.asScala.toMap) /** - * Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * Converts a column containing a `StructType` or `ArrayType` of `StructType`s into a JSON string + * with the specified schema. Throws an exception, in the case of an unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * * @group collection_funcs * @since 2.1.0 diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index 83243c5e5a12..b3cc2cea51d4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -3,6 +3,7 @@ describe function to_json; describe function extended to_json; select to_json(named_struct('a', 1, 'b', 2)); select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); +select to_json(array(named_struct('a', 1, 'b', 2))); -- Check if errors handled select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index b57cbbc1d843..315e1730ce7d 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 17 -- !query 0 @@ -7,7 +7,7 @@ describe function to_json -- !query 0 schema struct -- !query 0 output -Class: org.apache.spark.sql.catalyst.expressions.StructToJson +Class: org.apache.spark.sql.catalyst.expressions.StructsToJson Function: to_json Usage: to_json(expr[, options]) - Returns a json string with a given struct value @@ -17,13 +17,15 @@ describe function extended to_json -- !query 1 schema struct -- !query 1 output -Class: org.apache.spark.sql.catalyst.expressions.StructToJson +Class: org.apache.spark.sql.catalyst.expressions.StructsToJson Extended Usage: Examples: > SELECT to_json(named_struct('a', 1, 'b', 2)); {"a":1,"b":2} > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} + > SELECT to_json(array(named_struct('a', 1, 'b', 2)); + [{"a":1,"b":2}] Function: to_json Usage: to_json(expr[, options]) - 
Returns a json string with a given struct value @@ -32,7 +34,7 @@ Usage: to_json(expr[, options]) - Returns a json string with a given struct valu -- !query 2 select to_json(named_struct('a', 1, 'b', 2)) -- !query 2 schema -struct +struct -- !query 2 output {"a":1,"b":2} @@ -40,54 +42,62 @@ struct -- !query 3 select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')) -- !query 3 schema -struct +struct -- !query 3 output {"time":"26/08/2015"} -- !query 4 -select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) +select to_json(array(named_struct('a', 1, 'b', 2))) -- !query 4 schema -struct<> +struct -- !query 4 output -org.apache.spark.sql.AnalysisException -Must use a map() function for options;; line 1 pos 7 +[{"a":1,"b":2}] -- !query 5 -select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) -- !query 5 schema struct<> -- !query 5 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +Must use a map() function for options;; line 1 pos 7 -- !query 6 -select to_json() +select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) -- !query 6 schema struct<> -- !query 6 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json; line 1 pos 7 +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 -- !query 7 -describe function from_json +select to_json() -- !query 7 schema -struct +struct<> -- !query 7 output -Class: org.apache.spark.sql.catalyst.expressions.JsonToStruct -Function: from_json -Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function to_json; line 1 pos 7 -- !query 8 -describe function extended from_json +describe function from_json -- !query 8 schema struct -- !query 8 output -Class: org.apache.spark.sql.catalyst.expressions.JsonToStruct +Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs +Function: from_json +Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. + + +-- !query 9 +describe function extended from_json +-- !query 9 schema +struct +-- !query 9 output +Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs Extended Usage: Examples: > SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE'); @@ -99,36 +109,36 @@ Function: from_json Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. 
--- !query 9 +-- !query 10 select from_json('{"a":1}', 'a INT') --- !query 9 schema -struct> --- !query 9 output +-- !query 10 schema +struct> +-- !query 10 output {"a":1} --- !query 10 +-- !query 11 select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) --- !query 10 schema -struct> --- !query 10 output +-- !query 11 schema +struct> +-- !query 11 output {"time":2015-08-26 00:00:00.0} --- !query 11 +-- !query 12 select from_json('{"a":1}', 1) --- !query 11 schema +-- !query 12 schema struct<> --- !query 11 output +-- !query 12 output org.apache.spark.sql.AnalysisException Expected a string literal instead of 1;; line 1 pos 7 --- !query 12 +-- !query 13 select from_json('{"a":1}', 'a InvalidType') --- !query 12 schema +-- !query 13 schema struct<> --- !query 12 output +-- !query 13 output org.apache.spark.sql.AnalysisException DataType invalidtype() is not supported.(line 1, pos 2) @@ -139,28 +149,28 @@ a InvalidType ; line 1 pos 7 --- !query 13 +-- !query 14 select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')) --- !query 13 schema +-- !query 14 schema struct<> --- !query 13 output +-- !query 14 output org.apache.spark.sql.AnalysisException Must use a map() function for options;; line 1 pos 7 --- !query 14 +-- !query 15 select from_json('{"a":1}', 'a INT', map('mode', 1)) --- !query 14 schema +-- !query 15 schema struct<> --- !query 14 output +-- !query 15 output org.apache.spark.sql.AnalysisException A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 --- !query 15 +-- !query 16 select from_json() --- !query 15 schema +-- !query 16 schema struct<> --- !query 15 output +-- !query 16 output org.apache.spark.sql.AnalysisException Invalid number of arguments for function from_json; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 2345b8208116..170c238c5343 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -156,7 +156,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) } - test("to_json") { + test("to_json - struct") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") checkAnswer( @@ -164,6 +164,14 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("""{"_1":1}""") :: Nil) } + test("to_json - array") { + val df = Seq(Tuple1(Tuple1(1) :: Nil)).toDF("a") + + checkAnswer( + df.select(to_json($"a")), + Row("""[{"_1":1}]""") :: Nil) + } + test("to_json with option") { val df = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a") val options = Map("timestampFormat" -> "dd/MM/yyyy HH:mm") @@ -184,7 +192,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { "Unable to convert column a of type calendarinterval to JSON.")) } - test("roundtrip in to_json and from_json") { + test("roundtrip in to_json and from_json - struct") { val dfOne = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct") val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType] val readBackOne = dfOne.select(to_json($"struct").as("json")) @@ -198,6 +206,20 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(dfTwo, readBackTwo) } + test("roundtrip in to_json and from_json - array") { + val dfOne = 
Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") + val schemaOne = dfOne.schema(0).dataType + val readBackOne = dfOne.select(to_json($"array").as("json")) + .select(from_json($"json", schemaOne).as("array")) + checkAnswer(dfOne, readBackOne) + + val dfTwo = Seq(Some("""[{"a":1}]"""), None).toDF("json") + val schemaTwo = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val readBackTwo = dfTwo.select(from_json($"json", schemaTwo).as("array")) + .select(to_json($"array").as("json")) + checkAnswer(dfTwo, readBackTwo) + } + test("SPARK-19637 Support to_json in SQL") { val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") checkAnswer( From c40597720e8e66a6b11ca241b1ad387154a8fe72 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 19 Mar 2017 22:34:18 -0700 Subject: [PATCH 072/512] [SPARK-20020][SPARKR] DataFrame checkpoint API ## What changes were proposed in this pull request? Add checkpoint, setCheckpointDir API to R ## How was this patch tested? unit tests, manual tests Author: Felix Cheung Closes #17351 from felixcheung/rdfcheckpoint. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/DataFrame.R | 29 +++++++++++++++++++++++ R/pkg/R/RDD.R | 2 +- R/pkg/R/context.R | 21 +++++++++++++++- R/pkg/R/generics.R | 6 ++++- R/pkg/inst/tests/testthat/test_rdd.R | 4 ++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 11 +++++++++ 7 files changed, 70 insertions(+), 5 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 78344ce9ff08..8be7875ad2d5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -82,6 +82,7 @@ exportMethods("arrange", "as.data.frame", "attach", "cache", + "checkpoint", "coalesce", "collect", "colnames", @@ -369,6 +370,7 @@ export("as.DataFrame", "read.parquet", "read.stream", "read.text", + "setCheckpointDir", "spark.lapply", "spark.addFile", "spark.getSparkFilesRootDirectory", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index bc81633815c6..97786df4ae6a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3613,3 +3613,32 @@ setMethod("write.stream", ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) }) + +#' checkpoint +#' +#' Returns a checkpointed version of this SparkDataFrame. Checkpointing can be used to truncate the +#' logical plan, which is especially useful in iterative algorithms where the plan may grow +#' exponentially. 
It will be saved to files inside the checkpoint directory set with +#' \code{setCheckpointDir} +#' +#' @param x A SparkDataFrame +#' @param eager whether to checkpoint this SparkDataFrame immediately +#' @return a new checkpointed SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases checkpoint,SparkDataFrame-method +#' @rdname checkpoint +#' @name checkpoint +#' @seealso \link{setCheckpointDir} +#' @export +#' @examples +#'\dontrun{ +#' setCheckpointDir("/checkpoint") +#' df <- checkpoint(df) +#' } +#' @note checkpoint since 2.2.0 +setMethod("checkpoint", + signature(x = "SparkDataFrame"), + function(x, eager = TRUE) { + df <- callJMethod(x@sdf, "checkpoint", as.logical(eager)) + dataFrame(df) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 5667b9d78882..7ad3993e9ecb 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -291,7 +291,7 @@ setMethod("unpersistRDD", #' @rdname checkpoint-methods #' @aliases checkpoint,RDD-method #' @noRd -setMethod("checkpoint", +setMethod("checkpointRDD", signature(x = "RDD"), function(x) { jrdd <- getJRDD(x) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 1a0dd65f450b..cb0f83b2fa22 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -291,7 +291,7 @@ broadcast <- function(sc, object) { #' rdd <- parallelize(sc, 1:2, 2L) #' checkpoint(rdd) #'} -setCheckpointDir <- function(sc, dirName) { +setCheckpointDirSC <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } @@ -410,3 +410,22 @@ setLogLevel <- function(level) { sc <- getSparkContext() invisible(callJMethod(sc, "setLogLevel", level)) } + +#' Set checkpoint directory +#' +#' Set the directory under which SparkDataFrame are going to be checkpointed. The directory must be +#' a HDFS path if running on a cluster. +#' +#' @rdname setCheckpointDir +#' @param directory Directory path to checkpoint to +#' @seealso \link{checkpoint} +#' @export +#' @examples +#'\dontrun{ +#' setCheckpointDir("/checkpoint") +#'} +#' @note setCheckpointDir since 2.0.0 +setCheckpointDir <- function(directory) { + sc <- getSparkContext() + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(directory)))) +} diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 029771289fd5..80283e48ced7 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -32,7 +32,7 @@ setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coa # @rdname checkpoint-methods # @export -setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) +setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) @@ -406,6 +406,10 @@ setGeneric("attach") #' @export setGeneric("cache", function(x) { standardGeneric("cache") }) +#' @rdname checkpoint +#' @export +setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) + #' @rdname coalesce #' @param x a Column or a SparkDataFrame. #' @param ... additional argument(s). 
If \code{x} is a Column, additional Columns can be optionally diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 787ef51c501c..b72c801dd958 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -143,8 +143,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp expect_false(rdd2@env$isCached) tempDir <- tempfile(pattern = "checkpoint") - setCheckpointDir(sc, tempDir) - checkpoint(rdd2) + setCheckpointDirSC(sc, tempDir) + checkpointRDD(rdd2) expect_true(rdd2@env$isCheckpointed) rdd2 <- lapply(rdd2, function(x) x) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9c38e0d866aa..cbc3569795d9 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -841,6 +841,17 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", expect_true(is.data.frame(collect(df))) }) +test_that("setCheckpointDir(), checkpoint() on a DataFrame", { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) +}) + test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- read.json(jsonPath) testSchema <- schema(df) From 965a5abcff3adccc10a53b0d97d06c43934df1a2 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Mon, 20 Mar 2017 14:37:23 +0800 Subject: [PATCH 073/512] [SPARK-19994][SQL] Wrong outputOrdering for right/full outer smj ## What changes were proposed in this pull request? For right outer join, values of the left key will be filled with nulls if it can't match the value of the right key, so `nullOrdering` of the left key can't be guaranteed. We should output right key order instead of left key order. For full outer join, neither left key nor right key guarantees `nullOrdering`. We should not output any ordering. In tests, besides adding three test cases for left/right/full outer sort merge join, this patch also reorganizes code in `PlannerSuite` by putting together tests for `Sort`, and also extracts common logic in Sort tests into a method. ## How was this patch tested? Corresponding test cases are added. Author: wangzhenhua Author: Zhenhua Wang Closes #17331 from wzhfy/wrongOrdering. --- .../execution/joins/SortMergeJoinExec.scala | 12 +- .../spark/sql/execution/PlannerSuite.scala | 233 ++++++++++-------- 2 files changed, 146 insertions(+), 99 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index bcdc4dcdf7d9..02f4f55c7999 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -80,7 +80,17 @@ case class SortMergeJoinExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. 
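    // Illustrative reasoning, in the spirit of the commit message above: for
    //   SELECT * FROM l RIGHT OUTER JOIN r ON l.k = r.k
    // unmatched rows of r surface with null values for l.k, so only the order of r.k
    // survives the merge; a full outer join emits nulls on both sides, which is why it
    // reports no output ordering at all.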
+ case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) + // There are null rows in both streams, so there is no order. + case FullOuter => Nil + case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 02ccebd22bdf..f2232fc489b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation @@ -251,7 +251,9 @@ class PlannerSuite extends SharedSQLContext { } } - // --- Unit tests of EnsureRequirements --------------------------------------------------------- + /////////////////////////////////////////////////////////////////////////// + // Unit tests of EnsureRequirements for Exchange + /////////////////////////////////////////////////////////////////////////// // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, // there two dimensions that need to be considered: are the child partitionings compatible and @@ -384,93 +386,6 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements adds sort when there is no existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq.empty) :: Nil, - requiredChildOrdering = Seq(Seq(orderingB)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.isEmpty) { - fail(s"Sort should have been added:\n$outputPlan") - } - } - - test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { - fail(s"No sorts should have been added:\n$outputPlan") - } - } - - test("EnsureRequirements skips sort when required ordering is semantically equal to " + - "existing ordering") { - val 
exprId: ExprId = NamedExpression.newExprId - val attribute1 = - AttributeReference( - name = "col1", - dataType = LongType, - nullable = false - ) (exprId = exprId, - qualifier = Some("col1_qualifier") - ) - - val attribute2 = - AttributeReference( - name = "col1", - dataType = LongType, - nullable = false - ) (exprId = exprId) - - val orderingA1 = SortOrder(attribute1, Ascending) - val orderingA2 = SortOrder(attribute2, Ascending) - - assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") - assert(orderingA1.semanticEquals(orderingA2), - s"$orderingA1 should be semantically equal to $orderingA2") - - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA1)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA2)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { - fail(s"No sorts should have been added:\n$outputPlan") - } - } - - // This is a regression test for SPARK-11135 - test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA, orderingB)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.isEmpty) { - fail(s"Sort should have been added:\n$outputPlan") - } - } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -481,7 +396,7 @@ class PlannerSuite extends SharedSQLContext { children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), requiredChildOrdering = Seq(Seq.empty)), - None) + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) @@ -510,8 +425,6 @@ class PlannerSuite extends SharedSQLContext { } } - // --------------------------------------------------------------------------------------------- - test("Reuse exchanges") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -525,12 +438,12 @@ class PlannerSuite extends SharedSQLContext { None) val inputPlan = SortMergeJoinExec( - Literal(1) :: Nil, - Literal(1) :: Nil, - Inner, - None, - shuffle, - shuffle) + Literal(1) :: Nil, + Literal(1) :: Nil, + Inner, + None, + shuffle, + shuffle) val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { @@ -557,6 +470,130 @@ class PlannerSuite extends SharedSQLContext { fail(s"Should have only two shuffles:\n$outputPlan") } } + + /////////////////////////////////////////////////////////////////////////// + // Unit tests of EnsureRequirements for Sort + /////////////////////////////////////////////////////////////////////////// + 
+ private val exprA = Literal(1) + private val exprB = Literal(2) + private val orderingA = SortOrder(exprA, Ascending) + private val orderingB = SortOrder(exprB, Ascending) + private val planA = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: Nil, 5)) + private val planB = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + + assert(orderingA != orderingB) + + private def assertSortRequirementsAreSatisfied( + childPlan: SparkPlan, + requiredOrdering: Seq[SortOrder], + shouldHaveSort: Boolean): Unit = { + val inputPlan = DummySparkPlan( + children = childPlan :: Nil, + requiredChildOrdering = Seq(requiredOrdering), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (shouldHaveSort) { + if (outputPlan.collect { case s: SortExec => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } else { + if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + } + + test("EnsureRequirements for sort operator after left outer sort merge join") { + // Only left key is sorted after left outer SMJ (thus doesn't need a sort). + val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB) + Seq((orderingA, false), (orderingB, true)).foreach { case (ordering, needSort) => + assertSortRequirementsAreSatisfied( + childPlan = leftSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = needSort) + } + } + + test("EnsureRequirements for sort operator after right outer sort merge join") { + // Only right key is sorted after right outer SMJ (thus doesn't need a sort). + val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, None, planA, planB) + Seq((orderingA, true), (orderingB, false)).foreach { case (ordering, needSort) => + assertSortRequirementsAreSatisfied( + childPlan = rightSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = needSort) + } + } + + test("EnsureRequirements adds sort after full outer sort merge join") { + // Neither keys is sorted after full outer SMJ, so they both need sorts. 
+ val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, None, planA, planB) + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = fullSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = true) + } + } + + test("EnsureRequirements adds sort when there is no existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq.empty), + requiredOrdering = Seq(orderingB), + shouldHaveSort = true) + } + + test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)), + requiredOrdering = Seq(orderingA), + shouldHaveSort = false) + } + + test("EnsureRequirements skips sort when required ordering is semantically equal to " + + "existing ordering") { + val exprId: ExprId = NamedExpression.newExprId + val attribute1 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId, + qualifier = Some("col1_qualifier") + ) + + val attribute2 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId) + + val orderingA1 = SortOrder(attribute1, Ascending) + val orderingA2 = SortOrder(attribute2, Ascending) + + assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") + assert(orderingA1.semanticEquals(orderingA2), + s"$orderingA1 should be semantically equal to $orderingA2") + + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA1)), + requiredOrdering = Seq(orderingA2), + shouldHaveSort = false) + } + + // This is a regression test for SPARK-11135 + test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA)), + requiredOrdering = Seq(orderingA, orderingB), + shouldHaveSort = true) + } } // Used for unit-testing EnsureRequirements From f14f81e900e2e6c216055799584148a2c944268d Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 19 Mar 2017 23:49:26 -0700 Subject: [PATCH 074/512] [SPARK-20020][SPARKR][FOLLOWUP] DataFrame checkpoint API fix version tag ## What changes were proposed in this pull request? doc only change ## How was this patch tested? manual Author: Felix Cheung Closes #17356 from felixcheung/rdfcheckpoint2. --- R/pkg/R/context.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index cb0f83b2fa22..1ca573e5bd61 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -424,7 +424,7 @@ setLogLevel <- function(level) { #'\dontrun{ #' setCheckpointDir("/checkpoint") #'} -#' @note setCheckpointDir since 2.0.0 +#' @note setCheckpointDir since 2.2.0 setCheckpointDir <- function(directory) { sc <- getSparkContext() invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(directory)))) From 81639115947a13017d1637549a8f66ba599b27b8 Mon Sep 17 00:00:00 2001 From: Ioana Delaney Date: Mon, 20 Mar 2017 16:04:58 +0800 Subject: [PATCH 075/512] [SPARK-17791][SQL] Join reordering using star schema detection ## What changes were proposed in this pull request? Star schema consists of one or more fact tables referencing a number of dimension tables. In general, queries against star schema are expected to run fast because of the established RI constraints among the tables. 
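For illustration, a query against such a schema typically joins one large fact table to several much smaller dimension tables on FK/PK equality predicates, often with a selective local filter on a dimension. A minimal sketch in the DataFrame API, assuming a `spark` session is in scope and using hypothetical table and column names (`sales`, `date_dim`, `store_dim`, `item_dim`):

```scala
// Hypothetical star schema: "sales" is the fact table, the *_dim tables are dimensions.
val sales  = spark.table("sales")        // largest cardinality
val dates  = spark.table("date_dim")
val stores = spark.table("store_dim")
val items  = spark.table("item_dim")

val starQuery = sales
  .join(dates,  sales("date_fk")  === dates("date_pk"))     // FK -> PK equi-joins
  .join(stores, sales("store_fk") === stores("store_pk"))
  .join(items,  sales("item_fk")  === items("item_pk"))
  .where(stores("region") === "EU")                          // selective dimension predicate
```
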
This design proposes a join reordering based on natural, generally accepted heuristics for star schema queries: - Finds the star join with the largest fact table and places it on the driving arm of the left-deep join. This plan avoids large tables on the inner, and thus favors hash joins. - Applies the most selective dimensions early in the plan to reduce the amount of data flow. The design document was included in SPARK-17791. Link to the google doc: [StarSchemaDetection](https://docs.google.com/document/d/1UAfwbm_A6wo7goHlVZfYK99pqDMEZUumi7pubJXETEA/edit?usp=sharing) ## How was this patch tested? A new test suite StarJoinSuite.scala was implemented. Author: Ioana Delaney Closes #15363 from ioana-delaney/starJoinReord2. --- .../sql/catalyst/SimpleCatalystConf.scala | 1 + .../optimizer/CostBasedJoinReorder.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../spark/sql/catalyst/optimizer/joins.scala | 350 ++++++++++- .../sql/catalyst/planning/patterns.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 16 + .../optimizer/JoinOptimizationSuite.scala | 4 +- .../catalyst/optimizer/JoinReorderSuite.scala | 29 +- .../optimizer/StarJoinReorderSuite.scala | 580 ++++++++++++++++++ .../spark/sql/catalyst/plans/PlanTest.scala | 26 + 10 files changed, 978 insertions(+), 36 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala index 0d4903e03bf5..ac97987c55e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala @@ -40,6 +40,7 @@ case class SimpleCatalystConf( override val cboEnabled: Boolean = false, override val joinReorderEnabled: Boolean = false, override val joinReorderDPThreshold: Int = 12, + override val starSchemaDetection: Boolean = false, override val warehousePath: String = "/user/hive/warehouse", override val sessionLocalTimeZone: String = TimeZone.getDefault().getID, override val maxNestedViewDepth: Int = 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 1b32bda72bc9..521c468fe18a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -53,6 +53,8 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) + // TODO: Compute the set of star-joins and use them in the join enumeration + // algorithm to prune un-optimal plan choices. val result = // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c8ed4190a13a..d7524a57adbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -82,7 +82,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Operator Optimizations", fixedPoint, // Operator push down PushProjectionThroughUnion, - ReorderJoin, + ReorderJoin(conf), EliminateOuterJoin, PushPredicateThroughJoin, PushDownPredicate, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index bfe529e21e9a..58e4a230f4ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -20,19 +20,347 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Encapsulates star-schema join detection. + */ +case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { + + /** + * Star schema consists of one or more fact tables referencing a number of dimension + * tables. In general, star-schema joins are detected using the following conditions: + * 1. Informational RI constraints (reliable detection) + * + Dimension contains a primary key that is being joined to the fact table. + * + Fact table contains foreign keys referencing multiple dimension tables. + * 2. Cardinality based heuristics + * + Usually, the table with the highest cardinality is the fact table. + * + Table being joined with the most number of tables is the fact table. + * + * To detect star joins, the algorithm uses a combination of the above two conditions. + * The fact table is chosen based on the cardinality heuristics, and the dimension + * tables are chosen based on the RI constraints. A star join will consist of the largest + * fact table joined with the dimension tables on their primary keys. To detect that a + * column is a primary key, the algorithm uses table and column statistics. + * + * Since Catalyst only supports left-deep tree plans, the algorithm currently returns only + * the star join with the largest fact table. Choosing the largest fact table on the + * driving arm to avoid large inners is in general a good heuristic. This restriction can + * be lifted with support for bushy tree plans. + * + * The highlights of the algorithm are the following: + * + * Given a set of joined tables/plans, the algorithm first verifies if they are eligible + * for star join detection. An eligible plan is a base table access with valid statistics. + * A base table access represents Project or Filter operators above a LeafNode. Conservatively, + * the algorithm only considers base table access as part of a star join since they provide + * reliable statistics. 
+ * + * If some of the plans are not base table access, or statistics are not available, the algorithm + * returns an empty star join plan since, in the absence of statistics, it cannot make + * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality + * (number of rows), which is assumed to be a fact table. + * + * Next, it computes the set of dimension tables for the current fact table. A dimension table + * is assumed to be in a RI relationship with a fact table. To infer column uniqueness, + * the algorithm compares the number of distinct values with the total number of rows in the + * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted + * based on 1TB TPC-DS data), the column is assumed to be unique. + */ + def findStarJoins( + input: Seq[LogicalPlan], + conditions: Seq[Expression]): Seq[Seq[LogicalPlan]] = { + + val emptyStarJoinPlan = Seq.empty[Seq[LogicalPlan]] + + if (!conf.starSchemaDetection || input.size < 2) { + emptyStarJoinPlan + } else { + // Find if the input plans are eligible for star join detection. + // An eligible plan is a base table access with valid statistics. + val foundEligibleJoin = input.forall { + case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case _ => false + } + + if (!foundEligibleJoin) { + // Some plans don't have stats or are complex plans. Conservatively, + // return an empty star join. This restriction can be lifted + // once statistics are propagated in the plan. + emptyStarJoinPlan + } else { + // Find the fact table using cardinality based heuristics i.e. + // the table with the largest number of rows. + val sortedFactTables = input.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.collect { case t @ TableAccessCardinality(_, Some(_)) => + t + }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse) + + sortedFactTables match { + case Nil => + emptyStarJoinPlan + case table1 :: table2 :: _ + if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble => + // If the top largest tables have comparable number of rows, return an empty star plan. + // This restriction will be lifted when the algorithm is generalized + // to return multiple star plans. + emptyStarJoinPlan + case TableAccessCardinality(factTable, _) :: rest => + // Find the fact table joins. + val allFactJoins = rest.collect { case TableAccessCardinality(plan, _) + if findJoinConditions(factTable, plan, conditions).nonEmpty => + plan + } + + // Find the corresponding join conditions. + val allFactJoinCond = allFactJoins.flatMap { plan => + val joinCond = findJoinConditions(factTable, plan, conditions) + joinCond + } + + // Verify if the join columns have valid statistics. + // Allow any relational comparison between the tables. Later + // we will heuristically choose a subset of equi-join + // tables. + val areStatsAvailable = allFactJoins.forall { dimTable => + allFactJoinCond.exists { + case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs + hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable) + case _ => false + } + } + + if (!areStatsAvailable) { + emptyStarJoinPlan + } else { + // Find the subset of dimension tables. A dimension table is assumed to be in a + // RI relationship with the fact table. 
Only consider equi-joins + // between a fact and a dimension table to avoid expanding joins. + val eligibleDimPlans = allFactJoins.filter { dimTable => + allFactJoinCond.exists { + case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + isUnique(dimCol, dimTable) + case _ => false + } + } + + if (eligibleDimPlans.isEmpty) { + // An eligible star join was not found because the join is not + // an RI join, or the star join is an expanding join. + emptyStarJoinPlan + } else { + Seq(factTable +: eligibleDimPlans) + } + } + } + } + } + } + + /** + * Reorders a star join based on heuristics: + * 1) Finds the star join with the largest fact table and places it on the driving + * arm of the left-deep tree. This plan avoids large table access on the inner, and + * thus favor hash joins. + * 2) Applies the most selective dimensions early in the plan to reduce the amount of + * data flow. + */ + def reorderStarJoins( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = { + assert(input.size >= 2) + + val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)] + + // Find the eligible star plans. Currently, it only returns + // the star join with the largest fact table. + val eligibleJoins = input.collect{ case (plan, Inner) => plan } + val starPlans = findStarJoins(eligibleJoins, conditions) + + if (starPlans.isEmpty) { + emptyStarJoinPlan + } else { + val starPlan = starPlans.head + val (factTable, dimTables) = (starPlan.head, starPlan.tail) + + // Only consider selective joins. This case is detected by observing local predicates + // on the dimension tables. In a star schema relationship, the join between the fact and the + // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce + // the result of a join. + // Also, conservatively assume that a fact table is joined with more than one dimension. + if (dimTables.size >= 2 && isSelectiveStarJoin(dimTables, conditions)) { + val reorderDimTables = dimTables.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.sortBy(_.size).map { + case TableAccessCardinality(p1, _) => p1 + } + + val reorderStarPlan = factTable +: reorderDimTables + reorderStarPlan.map(plan => (plan, Inner)) + } else { + emptyStarJoinPlan + } + } + } + + /** + * Determines if a column referenced by a base table access is a primary key. + * A column is a PK if it is not nullable and has unique values. + * To determine if a column has unique values in the absence of informational + * RI constraints, the number of distinct values is compared to the total + * number of rows in the table. If their relative difference + * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based + * on TPCDS data results), the column is assumed to have unique values. 
+ */ + private def isUnique( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.rowCount match { + case Some(rowCount) if rowCount >= 0 => + if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { + val colStats = stats.attributeStats.get(col) + if (colStats.get.nullCount > 0) { + false + } else { + val distinctCount = colStats.get.distinctCount + val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) + // ndvMaxErr adjusted based on TPCDS 1TB data results + relDiff <= conf.ndvMaxError * 2 + } + } else { + false + } + case None => false + } + case None => false + } + case _ => false + } + + /** + * Given a column over a base table access, it returns + * the leaf node column from which the input column is derived. + */ + @tailrec + private def findLeafNodeCol( + column: Attribute, + plan: LogicalPlan): Option[Attribute] = plan match { + case pl @ PhysicalOperation(_, _, _: LeafNode) => + pl match { + case t: LeafNode if t.outputSet.contains(column) => + Option(column) + case p: Project if p.outputSet.exists(_.semanticEquals(column)) => + val col = p.outputSet.find(_.semanticEquals(column)).get + findLeafNodeCol(col, p.child) + case f: Filter => + findLeafNodeCol(column, f.child) + case _ => None + } + case _ => None + } + + /** + * Checks if a column has statistics. + * The column is assumed to be over a base table access. + */ + private def hasStatistics( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) + case None => false + } + case _ => false + } + + /** + * Returns the join predicates between two input plans. It only + * considers basic comparison operators. + */ + @inline + private def findJoinConditions( + plan1: LogicalPlan, + plan2: LogicalPlan, + conditions: Seq[Expression]): Seq[Expression] = { + val refs = plan1.outputSet ++ plan2.outputSet + conditions.filter { + case BinaryComparison(_, _) => true + case _ => false + }.filterNot(canEvaluate(_, plan1)) + .filterNot(canEvaluate(_, plan2)) + .filter(_.references.subsetOf(refs)) + } + + /** + * Checks if a star join is a selective join. A star join is assumed + * to be selective if there are local predicates on the dimension + * tables. + */ + private def isSelectiveStarJoin( + dimTables: Seq[LogicalPlan], + conditions: Seq[Expression]): Boolean = dimTables.exists { + case plan @ PhysicalOperation(_, p, _: LeafNode) => + // Checks if any condition applies to the dimension tables. + // Exclude the IsNotNull predicates until predicate selectivity is available. + // In most cases, this predicate is artificially introduced by the Optimizer + // to enforce nullability constraints. + val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull]) + .exists(canEvaluate(_, plan)) + + // Checks if there are any predicates pushed down to the base table access. + val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull]) + + localPredicates || pushedDownPredicates + case _ => false + } + + /** + * Helper case class to hold (plan, rowCount) pairs. 
+ */ + private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt]) + + /** + * Returns the cardinality of a base table access. A base table access represents + * a LeafNode, or Project or Filter operators above a LeafNode. + */ + private def getTableAccessCardinality( + input: LogicalPlan): Option[BigInt] = input match { + case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => + if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { + Option(input.stats(conf).rowCount.get) + } else { + Option(t.stats(conf).rowCount.get) + } + case _ => None + } +} /** * Reorder the joins and push all the conditions into join, so that the bottom ones have at least * one condition. * * The order of joins will not be changed if all of them already have at least one condition. + * + * If star schema detection is enabled, reorder the star join plans based on heuristics. */ -object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { - +case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Join a list of plans together and push down the conditions into them. * @@ -42,7 +370,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { * @param conditions a list of condition for join. */ @tailrec - def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) + final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) : LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { @@ -83,9 +411,19 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case j @ ExtractFiltersAndInnerJoins(input, conditions) + case ExtractFiltersAndInnerJoins(input, conditions) if input.size > 2 && conditions.nonEmpty => - createOrderedJoin(input, conditions) + if (conf.starSchemaDetection && !conf.cboEnabled) { + val starJoinPlan = StarSchemaDetection(conf).reorderStarJoins(input, conditions) + if (starJoinPlan.nonEmpty) { + val rest = input.filterNot(starJoinPlan.contains(_)) + createOrderedJoin(starJoinPlan ++ rest, conditions) + } else { + createOrderedJoin(input, conditions) + } + } else { + createOrderedJoin(input, conditions) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 0893af26738b..d39b0ef7e1d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -167,8 +167,8 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { case Join(left, right, joinType: InnerLike, cond) => val (plans, conditions) = flattenJoin(left, joinType) - (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq) - + (plans ++ Seq((right, joinType)), conditions ++ + cond.toSeq.flatMap(splitConjunctivePredicates)) case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) => val (plans, conditions) = flattenJoin(j) (plans, conditions ++ splitConjunctivePredicates(filterCondition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d2ac4b88ee8f..b6e0b8ccbeed 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -719,6 +719,18 @@ object SQLConf { .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") .createWithDefault(0.7) + val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection") + .doc("When true, it enables join reordering based on star schema detection. ") + .booleanConf + .createWithDefault(false) + + val STARSCHEMA_FACT_TABLE_RATIO = buildConf("spark.sql.cbo.starJoinFTRatio") + .internal() + .doc("Specifies the upper limit of the ratio between the largest fact tables" + + " for a star join to be considered. ") + .doubleConf + .createWithDefault(0.9) + val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""") @@ -988,6 +1000,10 @@ class SQLConf extends Serializable with Logging { def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) + def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION) + + def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 985e49069da9..61e81808147c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor - +import org.apache.spark.sql.catalyst.SimpleCatalystConf class JoinOptimizationSuite extends PlanTest { @@ -38,7 +38,7 @@ class JoinOptimizationSuite extends PlanTest { CombineFilters, PushDownPredicate, BooleanSimplification, - ReorderJoin, + ReorderJoin(SimpleCatalystConf(true)), PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 5607bcd16f3f..05b839b0119f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -22,10 +22,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} -import org.apache.spark.sql.catalyst.util._ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { @@ -38,7 +37,7 @@ class 
JoinReorderSuite extends PlanTest with StatsEstimationTestBase { Batch("Operator Optimizations", FixedPoint(100), CombineFilters, PushDownPredicate, - ReorderJoin, + ReorderJoin(conf), PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: @@ -203,27 +202,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { val optimized = Optimize.execute(originalPlan.analyze) - val normalized1 = normalizePlan(normalizeExprIds(optimized)) - val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze)) - if (!sameJoinPlan(normalized1, normalized2)) { - fail( - s""" - |== FAIL: Plans do not match === - |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} - """.stripMargin) - } - } - - /** Consider symmetry for joins when comparing plans. */ - private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { - (plan1, plan2) match { - case (j1: Join, j2: Join) => - (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || - (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) - case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => - (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } - case _ => - plan1 == plan2 - } + val expected = groundTruthBestPlan.analyze + compareJoinOrder(optimized, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala new file mode 100644 index 000000000000..93fdd98d1ac9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} + + +class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = SimpleCatalystConf( + caseSensitiveAnalysis = true, starSchemaDetection = true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + } + + // Table setup using star schema relationships: + // + // d1 - f1 - d2 + // | + // d3 - s3 + // + // Table f1 is the fact table. Tables d1, d2, and d3 are the dimension tables. + // Dimension d3 is further joined/normalized into table s3. + // Tables' cardinality: f1 > d3 > d1 > d2 > s3 + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + // F1 + attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // D1 + attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + // D2 + attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // D3 + attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + // S3 + attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c2") -> 
ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // F11 + attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val f1 = StatsTestPlan( + outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToAttr), + rowCount = 6, + size = Some(48), + attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToColInfo))) + + private val d1 = StatsTestPlan( + outputList = Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToAttr), + rowCount = 4, + size = Some(32), + attributeStats = AttributeMap(Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToColInfo))) + + private val d2 = StatsTestPlan( + outputList = Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToAttr), + rowCount = 3, + size = Some(24), + attributeStats = AttributeMap(Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToColInfo))) + + private val d3 = StatsTestPlan( + outputList = Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToAttr), + rowCount = 5, + size = Some(40), + attributeStats = AttributeMap(Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToColInfo))) + + private val s3 = StatsTestPlan( + outputList = Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToAttr), + rowCount = 2, + size = Some(17), + attributeStats = AttributeMap(Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToColInfo))) + + private val d3_ns = LocalRelation('d3_fk1.int, 'd3_c2.int, 'd3_pk1.int, 'd3_c4.int) + + private val f11 = StatsTestPlan( + outputList = Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4").map(nameToAttr), + rowCount = 6, + size = Some(48), + attributeStats = AttributeMap(Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4") + .map(nameToColInfo))) + + private val subq = d3.select(sum('d3_fk1).as('col)) + + test("Test 1: Selective star-join on all dimensions") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // s3 - d3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, d2, f1, d3, s3 + // where f1_fk2 = d2_pk1 and d2_c2 < 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d2, d1, d3, s3 + val query = + d1.join(d2).join(f1).join(d3).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === 
nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 2: Star join on a subset of dimensions due to inequality joins") { + // Star join: + // (=) (<) + // d1 - f1 - d2 + // | + // | (=) + // d3 - s3 + // (=) + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 < d2_pk1 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2,, d3 + + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 3: Star join on a subset of dimensions since join column is not unique") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 = d2_c4 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, d3 + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk2") === nameToAttr("d2_c4")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2"))) + + + assertEqualPlans(query, expected) + } + + test("Test 4: Star join on a subset of dimensions since join column is nullable") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // s3 - d3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 = d2_c2 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, s3 + + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk2") === nameToAttr("d2_c2")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2"))) + .join(s3, Inner, 
Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 5: Table stats not available for some of the joined tables") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3_ns - s3 + // + // select f1_fk1, f1_fk3 + // from d3_ns, f1, d1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d3_ns, f1, d1, d2, s3 + // Star join reordering: empty + + val query = + d3_ns.join(f1).join(d1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val equivQuery = + d3_ns.join(f1, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, equivQuery) + } + + test("Test 6: Join with complex plans") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // (sub-query) + // + // select f1_fk1, f1_fk3 + // from (select sum(d3_fk1) as col from d3) subq, f1, d1, d2 + // where f1_fk2 = d2_pk1 and d2_c2 < 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = sq.col + // + // Positional join reordering: d3, f1, d1, d2 + // Star join reordering: empty + + val query = + subq.join(f1).join(d1).join(d2) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === "col".attr)) + + val expected = + d3.select('d3_fk1).select(sum('d3_fk1).as('col)) + .join(f1, Inner, Some(nameToAttr("f1_fk3") === "col".attr)) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 7: Comparable fact table sizes") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // f11 - s3 + // + // select f1.f1_fk1, f1.f1_fk3 + // from d1, f11, f1, d2, s3 + // where f1.f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1.f1_fk1 = d1_pk1 + // and f1.f1_fk3 = f11.f1_fk3 + // and f11.f1_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, f11, d2, s3 + // Star join reordering: empty + + val query = + d1.join(f11).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("f11_fk3")) && + (nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + + val equivQuery = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(f11, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("f11_fk3"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, equivQuery) + } + + test("Test 8: No RI joins") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_c4 and d2_c2 = 2 + // and f1_fk1 = d1_c4 + 
// and f1_fk3 = d3_c4 + // and d3_fk1 = s3_pk1 + // + // Positional/default join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_c4")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_c4")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_c4")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_c4"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_c4"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 9: Complex join predicates") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and abs(f1_fk1) = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional/default join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 10: Less than two dimensions") { + // Star join: + // (<) (=) + // d1 - f1 - d2 + // |(<) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1_fk1 < d1_pk1 + // and f1_fk3 < d3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") < nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), + Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 11: Expanding star join") { + // Star join: + // (<) (<) + // d1 - f1 - d2 + // | (<) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 < d2_pk1 + // and f1_fk1 < d1_pk1 + // and f1_fk3 < d3_pk1 + // and d3_fk1 < s3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") < 
nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 12: Non selective star join") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val optimized = Optimize.execute(plan1.analyze) + val expected = plan2.analyze + compareJoinOrder(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 5eb31413ad70..2a9d0570148a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -106,4 +106,30 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { protected def compareExpressions(e1: Expression, e2: Expression): Unit = { comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) } + + /** Fails the test if the join order in the two plans do not match */ + protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. 
*/ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => + (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } + case _ => + plan1 == plan2 + } + } } From 7ce30e00b236e77b5175f797f9c6fc6cf4ca7e93 Mon Sep 17 00:00:00 2001 From: windpiger Date: Mon, 20 Mar 2017 21:36:00 +0800 Subject: [PATCH 076/512] [SPARK-19990][SQL][TEST-MAVEN] create a temp file for file in test.jar's resource when run mvn test accross different modules ## What changes were proposed in this pull request? After we have merged the `HiveDDLSuite` and `DDLSuite` in [SPARK-19235](https://issues.apache.org/jira/browse/SPARK-19235), we have two subclasses of `DDLSuite`, that is `HiveCatalogedDDLSuite` and `InMemoryCatalogDDLSuite`. While `DDLSuite` is in `sql/core module`, and `HiveCatalogedDDLSuite` is in `sql/hive module`, if we mvn test `HiveCatalogedDDLSuite`, it will run the test in its parent class `DDLSuite`, this will cause some test case failed which will get and use the test file path in `sql/core module` 's `resource`. Because the test file path getted will start with 'jar:' like "jar:file:/home/jenkins/workspace/spark-master-test-maven-hadoop-2.6/sql/core/target/spark-sql_2.11-2.2.0-SNAPSHOT-tests.jar!/test-data/cars.csv", which will failed when new Path() in datasource.scala This PR fix this by copy file from resource to a temp dir. ## How was this patch tested? N/A Author: windpiger Closes #17338 from windpiger/fixtestfailemvn. --- .../sql/execution/command/DDLSuite.scala | 33 +++++++++++-------- .../apache/spark/sql/test/SQLTestUtils.scala | 17 +++++++++- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index dd76fdde06cd..235c6bf6ad59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf @@ -699,21 +699,28 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("create temporary view using") { - val csvFile = - Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString - withView("testview") { - sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + - "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + - s"OPTIONS (PATH '$csvFile')") + // when we test the HiveCatalogedDDLSuite, it will failed because the csvFile path above + // starts with 'jar:', 
and it is an illegal parameter for Path, so here we copy it + // to a temp file by withResourceTempPath + withResourceTempPath("test-data/cars.csv") { tmpFile => + withView("testview") { + sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + + s"OPTIONS (PATH '$tmpFile')") - checkAnswer( - sql("select c1, c2 from testview order by c1 limit 1"), + checkAnswer( + sql("select c1, c2 from testview order by c1 limit 1"), Row("1997", "Ford") :: Nil) - // Fails if creating a new view with the same name - intercept[TempTableAlreadyExistsException] { - sql(s"CREATE TEMPORARY VIEW testview USING " + - s"org.apache.spark.sql.execution.datasources.csv.CSVFileFormat OPTIONS (PATH '$csvFile')") + // Fails if creating a new view with the same name + intercept[TempTableAlreadyExistsException] { + sql( + s""" + |CREATE TEMPORARY VIEW testview + |USING org.apache.spark.sql.execution.datasources.csv.CSVFileFormat + |OPTIONS (PATH '$tmpFile') + """.stripMargin) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 9201954b66d1..cab219216d1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.test import java.io.File import java.net.URI +import java.nio.file.Files import java.util.UUID import scala.language.implicitConversions -import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -123,6 +123,21 @@ private[sql] trait SQLTestUtils try f(path) finally Utils.deleteRecursively(path) } + /** + * Copy file in jar's resource to a temp file, then pass it to `f`. + * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of + * path of jar's resource which starts with 'jar:file:/' + */ + protected def withResourceTempPath(resourcePath: String)(f: File => Unit): Unit = { + val inputStream = + Thread.currentThread().getContextClassLoader.getResourceAsStream(resourcePath) + withTempDir { dir => + val tmpFile = new File(dir, "tmp") + Files.copy(inputStream, tmpFile.toPath) + f(tmpFile) + } + } + /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. From fc7554599a4b6e5c22aa35e7296b424a653a420b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 20 Mar 2017 10:07:31 -0700 Subject: [PATCH 077/512] [SPARK-19970][SQL] Table owner should be USER instead of PRINCIPAL in kerberized clusters ## What changes were proposed in this pull request? In the kerberized hadoop cluster, when Spark creates tables, the owner of tables are filled with PRINCIPAL strings instead of USER names. This is inconsistent with Hive and causes problems when using [ROLE](https://cwiki.apache.org/confluence/display/Hive/SQL+Standard+Based+Hive+Authorization) in Hive. We had better to fix this. **BEFORE** ```scala scala> sql("create table t(a int)").show scala> sql("desc formatted t").show(false) ... |Owner: |sparkEXAMPLE.COM | | ``` **AFTER** ```scala scala> sql("create table t(a int)").show scala> sql("desc formatted t").show(false) ... |Owner: |spark | | ``` ## How was this patch tested? Manually do `create table` and `desc formatted` because this happens in Kerberized clusters. Author: Dongjoon Hyun Closes #17311 from dongjoon-hyun/SPARK-19970. 
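
The patch itself reads the user through Hive's `SessionState` authenticator (see the diff below). Purely as an illustration of why the two names differ, and assuming Hadoop's `UserGroupInformation` API with a Kerberos login, a short sketch:

```scala
import org.apache.hadoop.security.UserGroupInformation

// With a Kerberos login, the UGI user name is the full principal, while the
// short name is the result of the hadoop.security.auth_to_local mapping.
val ugi = UserGroupInformation.getCurrentUser
println(ugi.getUserName)       // e.g. "spark@EXAMPLE.COM" (principal-style)
println(ugi.getShortUserName)  // e.g. "spark" (the plain user name)
```
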
--- .../scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 989fdc5564d3..13edcd051768 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -851,7 +851,7 @@ private[hive] object HiveClientImpl { hiveTable.setFields(schema.asJava) } hiveTable.setPartCols(partCols.asJava) - conf.foreach(c => hiveTable.setOwner(c.getUser)) + conf.foreach { _ => hiveTable.setOwner(SessionState.get().getAuthenticator().getUserName()) } hiveTable.setCreateTime((table.createTime / 1000).toInt) hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) table.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach { loc => From bec6b16c1900fe93def89cc5eb51cbef498196cb Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 20 Mar 2017 10:58:30 -0700 Subject: [PATCH 078/512] [SPARK-19899][ML] Replace featuresCol with itemsCol in ml.fpm.FPGrowth ## What changes were proposed in this pull request? Replaces `featuresCol` `Param` with `itemsCol`. See [SPARK-19899](https://issues.apache.org/jira/browse/SPARK-19899). ## How was this patch tested? Manual tests. Existing unit tests. Author: zero323 Closes #17321 from zero323/SPARK-19899. --- .../org/apache/spark/ml/fpm/FPGrowth.scala | 35 +++++++++++++------ .../apache/spark/ml/fpm/FPGrowthSuite.scala | 14 ++++---- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index fa39dd954af5..e2bc270b38da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.HasPredictionCol import org.apache.spark.ml.util._ import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth} @@ -37,7 +37,20 @@ import org.apache.spark.sql.types._ /** * Common params for FPGrowth and FPGrowthModel */ -private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol { +private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { + + /** + * Items column name. + * Default: "items" + * @group param + */ + @Since("2.2.0") + val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name") + setDefault(itemsCol -> "items") + + /** @group getParam */ + @Since("2.2.0") + def getItemsCol: String = $(itemsCol) /** * Minimal support level of the frequent pattern. [0.0, 1.0]. 
Any pattern that appears @@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre */ @Since("2.2.0") protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputType = schema($(featuresCol)).dataType + val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], s"The input column must be ArrayType, but got $inputType.") - SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType) + SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } @@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setItemsCol(value: String): this.type = set(itemsCol, value) /** @group setParam */ @Since("2.2.0") @@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { - val data = dataset.select($(featuresCol)) - val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val data = dataset.select($(itemsCol)) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) @@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") ( val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( - StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false), + StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) @@ -198,7 +211,7 @@ class FPGrowthModel private[ml] ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setItemsCol(value: String): this.type = set(itemsCol, value) /** @group setParam */ @Since("2.2.0") @@ -235,7 +248,7 @@ class FPGrowthModel private[ml] ( .collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]] val brRules = dataset.sparkSession.sparkContext.broadcast(rules) - val dt = dataset.schema($(featuresCol)).dataType + val dt = dataset.schema($(itemsCol)).dataType // For each rule, examine the input items and summarize the consequents val predictUDF = udf((items: Seq[_]) => { if (items != null) { @@ -249,7 +262,7 @@ class FPGrowthModel private[ml] ( } else { Seq.empty }}, dt) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol)))) } @Since("2.2.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 910d4b07d130..4603a618d2f9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -34,7 +34,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("FPGrowth fit and transform with different data types") { Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt => - val data = dataset.withColumn("features", col("features").cast(ArrayType(dt))) + val data = dataset.withColumn("items", col("items").cast(ArrayType(dt))) val model = 
new FPGrowth().setMinSupport(0.5).fit(data) val generatedRules = model.setMinConfidence(0.5).associationRules val expectedRules = spark.createDataFrame(Seq( @@ -52,8 +52,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (0, Array("1", "2"), Array.emptyIntArray), (0, Array("1", "2"), Array.emptyIntArray), (0, Array("1", "3"), Array(2)) - )).toDF("id", "features", "prediction") - .withColumn("features", col("features").cast(ArrayType(dt))) + )).toDF("id", "items", "prediction") + .withColumn("items", col("items").cast(ArrayType(dt))) .withColumn("prediction", col("prediction").cast(ArrayType(dt))) assert(expectedTransformed.collect().toSet.equals( transformed.collect().toSet)) @@ -79,7 +79,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (1, Array("1", "2", "3", "5")), (2, Array("1", "2", "3", "4")), (3, null.asInstanceOf[Array[String]]) - )).toDF("id", "features") + )).toDF("id", "items") val model = new FPGrowth().setMinSupport(0.7).fit(dataset) val prediction = model.transform(df) assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) @@ -108,11 +108,11 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val dataset = spark.createDataFrame(Seq( Array("1", "3"), Array("2", "3") - ).map(Tuple1(_))).toDF("features") + ).map(Tuple1(_))).toDF("items") val model = new FPGrowth().fit(dataset) val prediction = model.transform( - spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features") + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") ).first().getAs[Seq[String]]("prediction") assert(prediction === Seq("3")) @@ -127,7 +127,7 @@ object FPGrowthSuite { (0, Array("1", "2")), (0, Array("1", "2")), (0, Array("1", "3")) - )).toDF("id", "features") + )).toDF("id", "items") } /** From c2d1761a57f5d175913284533b3d0417e8718688 Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Mon, 20 Mar 2017 17:18:59 -0700 Subject: [PATCH 079/512] [SPARK-19906][SS][DOCS] Documentation describing how to write queries to Kafka ## What changes were proposed in this pull request? Add documentation that describes how to write streaming and batch queries to Kafka. zsxwing tdas Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tyson Condie Closes #17246 from tcondie/kafka-write-docs. --- .../structured-streaming-kafka-integration.md | 321 ++++++++++++++---- 1 file changed, 264 insertions(+), 57 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 522e66956867..217c1a91a16f 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -3,9 +3,9 @@ layout: global title: Structured Streaming + Kafka Integration Guide (Kafka broker version 0.10.0 or higher) --- -Structured Streaming integration for Kafka 0.10 to poll data from Kafka. +Structured Streaming integration for Kafka 0.10 to read data from and write data to Kafka. -### Linking +## Linking For Scala/Java applications using SBT/Maven project definitions, link your application with the following artifact: groupId = org.apache.spark @@ -15,40 +15,42 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. 
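For concreteness (this snippet is not in the documentation change itself, and the version shown is only a placeholder), the coordinates above translate into an SBT dependency roughly like:

```scala
// Hypothetical build.sbt fragment; substitute the Spark version your cluster runs.
// The %% operator appends the Scala binary version, e.g. spark-sql-kafka-0-10_2.11 on Scala 2.11.
libraryDependencies += "org.apache.spark" %% "spark-sql-kafka-0-10" % "2.2.0-SNAPSHOT"
```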
-### Creating a Kafka Source Stream +## Reading Data from Kafka + +### Creating a Kafka Source for Streaming Queries
    {% highlight scala %} // Subscribe to 1 topic -val ds1 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to multiple topics -val ds2 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1,topic2") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to a pattern -val ds3 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribePattern", "topic.*") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] {% endhighlight %} @@ -57,31 +59,31 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic -Dataset ds1 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to multiple topics -Dataset ds2 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1,topic2") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to a pattern -Dataset ds3 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribePattern", "topic.*") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %}
    @@ -89,37 +91,37 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight python %} # Subscribe to 1 topic -ds1 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to multiple topics -ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1,topic2") \ .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to a pattern -ds3 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribePattern", "topic.*") \ .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %}
  • -### Creating a Kafka Source Batch +### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, you can create an Dataset/DataFrame for a defined range of offsets. @@ -128,17 +130,17 @@ you can create an Dataset/DataFrame for a defined range of offsets. {% highlight scala %} // Subscribe to 1 topic defaults to the earliest and latest offsets -val ds1 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to multiple topics, specifying explicit Kafka offsets -val ds2 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -146,11 +148,11 @@ val ds2 = spark .option("startingOffsets", """{"topic1":{"0":23,"1":-2},"topic2":{"0":-2}}""") .option("endingOffsets", """{"topic1":{"0":50,"1":-1},"topic2":{"0":-1}}""") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to a pattern, at the earliest and latest offsets -val ds3 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -158,7 +160,7 @@ val ds3 = spark .option("startingOffsets", "earliest") .option("endingOffsets", "latest") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] {% endhighlight %} @@ -167,16 +169,16 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic defaults to the earliest and latest offsets -Dataset ds1 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load(); -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics, specifying explicit Kafka offsets -Dataset ds2 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -184,10 +186,10 @@ Dataset ds2 = spark .option("startingOffsets", "{\"topic1\":{\"0\":23,\"1\":-2},\"topic2\":{\"0\":-2}}") .option("endingOffsets", "{\"topic1\":{\"0\":50,\"1\":-1},\"topic2\":{\"0\":-1}}") .load(); -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern, at the earliest and latest offsets -Dataset ds3 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -195,7 +197,7 @@ Dataset ds3 = spark .option("startingOffsets", "earliest") .option("endingOffsets", "latest") .load(); -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); {% endhighlight %} @@ -203,16 +205,16 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); {% highlight python %} # Subscribe to 1 topic defaults to the earliest and latest offsets -ds1 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ .option("subscribe", "topic1") \ .load() -ds1.selectExpr("CAST(key AS 
STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to multiple topics, specifying explicit Kafka offsets -ds2 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ @@ -220,10 +222,10 @@ ds2 = spark \ .option("startingOffsets", """{"topic1":{"0":23,"1":-2},"topic2":{"0":-2}}""") \ .option("endingOffsets", """{"topic1":{"0":50,"1":-1},"topic2":{"0":-1}}""") \ .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to a pattern, at the earliest and latest offsets -ds3 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ @@ -231,8 +233,7 @@ ds3 = spark \ .option("startingOffsets", "earliest") \ .option("endingOffsets", "latest") \ .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %} @@ -373,11 +374,213 @@ The following configurations are optional: +## Writing Data to Kafka + +Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that +Apache Kafka only supports at least once write semantics. Consequently, when writing---either Streaming Queries +or Batch Queries---to Kafka, some records may be duplicated; this can happen, for example, if Kafka needs +to retry a message that was not acknowledged by a Broker, even though that Broker received and wrote the message record. +Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, +if writing the query is successful, then you can assume that the query output was written at least once. A possible +solution to remove duplicates when reading the written data could be to introduce a primary (unique) key +that can be used to perform de-duplication when reading. + +The Dataframe being written to Kafka should have the following columns in schema: + + + + + + + + + + + + + + +
+<tr><th>Column</th><th>Type</th></tr>
+<tr><td>key (optional)</td><td>string or binary</td></tr>
+<tr><td>value (required)</td><td>string or binary</td></tr>
+<tr><td>topic (*optional)</td><td>string</td></tr>
    +\* The topic column is required if the "topic" configuration option is not specified.
    + +The value column is the only required option. If a key column is not specified then +a ```null``` valued key column will be automatically added (see Kafka semantics on +how ```null``` valued key values are handled). If a topic column exists then its value +is used as the topic when writing the given row to Kafka, unless the "topic" configuration +option is set i.e., the "topic" configuration option overrides the topic column. + +The following options must be set for the Kafka sink +for both batch and streaming queries. + + + + + + + + +
+<tr><th>Option</th><th>value</th><th>meaning</th></tr>
+<tr><td>kafka.bootstrap.servers</td><td>A comma-separated list of host:port</td><td>The Kafka "bootstrap.servers" configuration.</td></tr>
    + +The following configurations are optional: + + + + + + + + + + +
+<tr><th>Option</th><th>value</th><th>default</th><th>query type</th><th>meaning</th></tr>
+<tr><td>topic</td><td>string</td><td>none</td><td>streaming and batch</td><td>Sets the topic that all rows will be written to in Kafka. This option overrides any topic column that may exist in the data.</td></tr>
    + +### Creating a Kafka Sink for Streaming Queries + +
    +
    +{% highlight scala %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +val ds = df + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .start() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +val ds = df + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .start() + +{% endhighlight %} +
    +
    +{% highlight java %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +StreamingQuery ds = df + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .start() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +StreamingQuery ds = df + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .start() + +{% endhighlight %} +
    +
    +{% highlight python %} + +# Write key-value data from a DataFrame to a specific Kafka topic specified in an option +ds = df \ + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .start() + +# Write key-value data from a DataFrame to Kafka using a topic specified in the data +ds = df \ + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .start() + +{% endhighlight %} +
    +
    + +### Writing the output of Batch Queries to Kafka + +
    +
    +{% highlight scala %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .write + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .save() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .write + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .save() + +{% endhighlight %} +
    +
    +{% highlight java %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .write() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .save() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .write() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .save() + +{% endhighlight %} +
    +
    +{% highlight python %} + +# Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .write \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .save() + +# Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ + .write \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .save() + +{% endhighlight %} +
    +
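Tying back to the at-least-once caveat at the start of this section (a hedged sketch, not part of the patch): if the producer writes a unique identifier into the record key, a reader can drop duplicates with ordinary DataFrame operations. The topic and server names below are placeholders.

```scala
// Hypothetical de-duplication on read; assumes each record's key carries a unique id.
val deduped = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("subscribe", "topic1")
  .load()
  .selectExpr("CAST(key AS STRING) AS id", "CAST(value AS STRING) AS value")
  .dropDuplicates("id")
```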
    + + +## Kafka Specific Configurations + Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, -`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafkaParams, see -[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see +[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs) for +parameters related to reading data, and [Kafka producer config docs](http://kafka.apache.org/documentation/#producerconfigs) +for parameters related to writing data. -Note that the following Kafka params cannot be set and the Kafka source will throw an exception: +Note that the following Kafka params cannot be set and the Kafka source or sink will throw an exception: - **group.id**: Kafka source will create a unique group id for each query automatically. - **auto.offset.reset**: Set the source option `startingOffsets` to specify @@ -389,11 +592,15 @@ Note that the following Kafka params cannot be set and the Kafka source will thr DataFrame operations to explicitly deserialize the keys. - **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the values. +- **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use +DataFrame operations to explicitly serialize the keys into either strings or byte arrays. +- **value.serializer**: values are always serialized with ByteArraySerializer or StringSerializer. Use +DataFrame oeprations to explicitly serialize the values into either strings or byte arrays. - **enable.auto.commit**: Kafka source doesn't commit any offset. - **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to use ConsumerInterceptor as it may break the query. -### Deploying +## Deploying As with any Spark applications, `spark-submit` is used to launch your application. `spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages`, such as, From 10691d36de902e3771af20aed40336b4f99de719 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 20 Mar 2017 18:25:59 -0700 Subject: [PATCH 080/512] [SPARK-19573][SQL] Make NaN/null handling consistent in approxQuantile ## What changes were proposed in this pull request? update `StatFunctions.multipleApproxQuantiles` to handle NaN/null ## How was this patch tested? existing tests and added tests Author: Zheng RuiFeng Closes #16971 from zhengruifeng/quantiles_nan. 
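To make the new NaN/null behavior concrete, here is a small usage sketch (not taken from the patch; the column names and values are invented):

```scala
// Hedged sketch assuming a local SparkSession; shows the per-column NaN/null handling.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("approx-quantile-demo").getOrCreate()
import spark.implicits._

val df = Seq(
  (1.0, Option.empty[Double]),
  (Double.NaN, Option.empty[Double]),
  (2.0, Option.empty[Double])
).toDF("a", "b")

// Column "a": the NaN row is ignored, so the median is computed over 1.0 and 2.0.
// Column "b": contains only nulls, so an empty array comes back for that column.
val quantiles = df.stat.approxQuantile(Array("a", "b"), Array(0.5), 0.0)
quantiles.foreach(q => println(q.mkString("[", ", ", "]")))
```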
--- .../aggregate/ApproximatePercentile.scala | 3 +- .../sql/catalyst/util/QuantileSummaries.scala | 12 ++-- .../util/QuantileSummariesSuite.scala | 46 ++++++++++----- .../spark/sql/DataFrameStatFunctions.scala | 21 +++---- .../sql/execution/stat/StatFunctions.scala | 10 +++- .../apache/spark/sql/DataFrameStatSuite.scala | 57 ++++++++++++------- 6 files changed, 95 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index db062f1a543f..1ec2e4a9e931 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -245,7 +245,8 @@ object ApproximatePercentile { val result = new Array[Double](percentages.length) var i = 0 while (i < percentages.length) { - result(i) = summaries.query(percentages(i)) + // Since summaries.count != 0, the query here never return None. + result(i) = summaries.query(percentages(i)).get i += 1 } result diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index 04f4ff2a9224..af543b04ba78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -176,17 +176,19 @@ class QuantileSummaries( * @param quantile the target quantile * @return */ - def query(quantile: Double): Double = { + def query(quantile: Double): Option[Double] = { require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]") require(headSampled.isEmpty, "Cannot operate on an uncompressed summary, call compress() first") + if (sampled.isEmpty) return None + if (quantile <= relativeError) { - return sampled.head.value + return Some(sampled.head.value) } if (quantile >= 1 - relativeError) { - return sampled.last.value + return Some(sampled.last.value) } // Target rank @@ -200,11 +202,11 @@ class QuantileSummaries( minRank += curSample.g val maxRank = minRank + curSample.delta if (maxRank - targetError <= rank && rank <= minRank + targetError) { - return curSample.value + return Some(curSample.value) } i += 1 } - sampled.last.value + Some(sampled.last.value) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala index 5e90970b1bb2..df579d5ec1dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala @@ -55,15 +55,19 @@ class QuantileSummariesSuite extends SparkFunSuite { } private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { - val approx = summary.query(quant) - // The rank of the approximation. 
- val rank = data.count(_ < approx) // has to be <, not <= to be exact - val lower = math.floor((quant - summary.relativeError) * data.size) - val upper = math.ceil((quant + summary.relativeError) * data.size) - val msg = - s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" - assert(rank >= lower, msg) - assert(rank <= upper, msg) + if (data.nonEmpty) { + val approx = summary.query(quant).get + // The rank of the approximation. + val rank = data.count(_ < approx) // has to be <, not <= to be exact + val lower = math.floor((quant - summary.relativeError) * data.size) + val upper = math.ceil((quant + summary.relativeError) * data.size) + val msg = + s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" + assert(rank >= lower, msg) + assert(rank <= upper, msg) + } else { + assert(summary.query(quant).isEmpty) + } } for { @@ -74,9 +78,9 @@ class QuantileSummariesSuite extends SparkFunSuite { test(s"Extremas with epsi=$epsi and seq=$seq_name, compression=$compression") { val s = buildSummary(data, epsi, compression) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") } @@ -100,6 +104,18 @@ class QuantileSummariesSuite extends SparkFunSuite { checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) } + + test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") { + val emptyData = Seq.empty[Double] + val s = buildSummary(emptyData, epsi, compression) + assert(s.count == 0, s"Found count=${s.count} but data size=0") + assert(s.sampled.isEmpty, s"if QuantileSummaries is empty, sampled should be empty") + checkQuantile(0.9999, emptyData, s) + checkQuantile(0.9, emptyData, s) + checkQuantile(0.5, emptyData, s) + checkQuantile(0.1, emptyData, s) + checkQuantile(0.001, emptyData, s) + } } // Tests for merging procedure @@ -118,9 +134,9 @@ class QuantileSummariesSuite extends SparkFunSuite { val s1 = buildSummary(data1, epsi, compression) val s2 = buildSummary(data2, epsi, compression) val s = s1.merge(s2) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") checkQuantile(0.9999, data, s) checkQuantile(0.9, data, s) @@ -137,9 +153,9 @@ class QuantileSummariesSuite extends SparkFunSuite { val s1 = buildSummary(data11, epsi, compression) val s2 = buildSummary(data12, epsi, compression) val s = s1.merge(s2) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") checkQuantile(0.9999, data, s) checkQuantile(0.9, data, s) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index bdcdf0c61ff3..c856d3099f6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -64,7 +64,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return the approximate quantiles at the given probabilities * * @note null and NaN values will be removed from the numerical column before calculation. If - * the dataframe is empty or all rows contain null or NaN, null is returned. + * the dataframe is empty or the column only contains null or NaN, an empty array is returned. * * @since 2.0.0 */ @@ -72,8 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { col: String, probabilities: Array[Double], relativeError: Double): Array[Double] = { - val res = approxQuantile(Array(col), probabilities, relativeError) - Option(res).map(_.head).orNull + approxQuantile(Array(col), probabilities, relativeError).head } /** @@ -89,8 +88,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities of each column * - * @note Rows containing any null or NaN values will be removed before calculation. If - * the dataframe is empty or all rows contain null or NaN, null is returned. + * @note null and NaN values will be ignored in numerical columns before calculation. For + * columns only containing null or NaN values, an empty array is returned. * * @since 2.2.0 */ @@ -98,13 +97,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { cols: Array[String], probabilities: Array[Double], relativeError: Double): Array[Array[Double]] = { - // TODO: Update NaN/null handling to keep consistent with the single-column version - try { - StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols, - probabilities, relativeError).map(_.toArray).toArray - } catch { - case e: NoSuchElementException => null - } + StatFunctions.multipleApproxQuantiles( + df.select(cols.map(col): _*), + cols, + probabilities, + relativeError).map(_.toArray).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index c3d8859cb7a9..1debad03c93f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -54,6 +54,9 @@ object StatFunctions extends Logging { * Note that values greater than 1 are accepted but give the same result as 1. * * @return for each column, returns the requested approximations + * + * @note null and NaN values will be ignored in numerical columns before calculation. For + * a column only containing null or NaN values, an empty array is returned. 
*/ def multipleApproxQuantiles( df: DataFrame, @@ -78,7 +81,10 @@ object StatFunctions extends Logging { def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = { var i = 0 while (i < summaries.length) { - summaries(i) = summaries(i).insert(row.getDouble(i)) + if (!row.isNullAt(i)) { + val v = row.getDouble(i) + if (!v.isNaN) summaries(i) = summaries(i).insert(v) + } i += 1 } summaries @@ -91,7 +97,7 @@ object StatFunctions extends Logging { } val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge) - summaries.map { summary => probabilities.map(summary.query) } + summaries.map { summary => probabilities.flatMap(summary.query) } } /** Calculate the Pearson Correlation Coefficient for the given columns */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index d0910e618a04..97890a035a62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -171,15 +171,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0) } assert(e2.getMessage.contains("Relative Error must be non-negative")) - - // return null if the dataset is empty - val res1 = df.selectExpr("*").limit(0) - .stat.approxQuantile("singles", Array(q1, q2), epsilons.head) - assert(res1 === null) - - val res2 = df.selectExpr("*").limit(0) - .stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilons.head) - assert(res2 === null) } test("approximate quantile 2: test relativeError greater than 1 return the same result as 1") { @@ -214,20 +205,48 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val q1 = 0.5 val q2 = 0.8 val epsilon = 0.1 - val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0), Row(1.0, 1.0), - Row(-1.0, Double.NaN), Row(Double.NaN, Double.NaN), Row(null, null), Row(null, 1.0), - Row(-1.0, null), Row(Double.NaN, null))) + val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0, Double.NaN), + Row(1.0, -1.0, null), Row(-1.0, Double.NaN, null), Row(Double.NaN, Double.NaN, null), + Row(null, null, Double.NaN), Row(null, 1.0, null), Row(-1.0, null, Double.NaN), + Row(Double.NaN, null, null))) val schema = StructType(Seq(StructField("input1", DoubleType, nullable = true), - StructField("input2", DoubleType, nullable = true))) + StructField("input2", DoubleType, nullable = true), + StructField("input3", DoubleType, nullable = true))) val dfNaN = spark.createDataFrame(rows, schema) - val resNaN = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) - assert(resNaN.count(_.isNaN) === 0) - assert(resNaN.count(_ == null) === 0) - val resNaN2 = dfNaN.stat.approxQuantile(Array("input1", "input2"), + val resNaN1 = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) + assert(resNaN1.count(_.isNaN) === 0) + assert(resNaN1.count(_ == null) === 0) + + val resNaN2 = dfNaN.stat.approxQuantile("input2", Array(q1, q2), epsilon) + assert(resNaN2.count(_.isNaN) === 0) + assert(resNaN2.count(_ == null) === 0) + + val resNaN3 = dfNaN.stat.approxQuantile("input3", Array(q1, q2), epsilon) + assert(resNaN3.isEmpty) + + val resNaNAll = dfNaN.stat.approxQuantile(Array("input1", "input2", "input3"), Array(q1, q2), epsilon) - assert(resNaN2.flatten.count(_.isNaN) === 0) - assert(resNaN2.flatten.count(_ == null) === 0) + 
assert(resNaNAll.flatten.count(_.isNaN) === 0) + assert(resNaNAll.flatten.count(_ == null) === 0) + + assert(resNaN1(0) === resNaNAll(0)(0)) + assert(resNaN1(1) === resNaNAll(0)(1)) + assert(resNaN2(0) === resNaNAll(1)(0)) + assert(resNaN2(1) === resNaNAll(1)(1)) + + // return empty array for columns only containing null or NaN values + assert(resNaNAll(2).isEmpty) + + // return empty array if the dataset is empty + val res1 = dfNaN.selectExpr("*").limit(0) + .stat.approxQuantile("input1", Array(q1, q2), epsilon) + assert(res1.isEmpty) + + val res2 = dfNaN.selectExpr("*").limit(0) + .stat.approxQuantile(Array("input1", "input2"), Array(q1, q2), epsilon) + assert(res2(0).isEmpty) + assert(res2(1).isEmpty) } test("crosstab") { From e9c91badce64731ffd3e53cbcd9f044a7593e6b8 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 21 Mar 2017 10:43:17 +0800 Subject: [PATCH 081/512] [SPARK-20010][SQL] Sort information is lost after sort merge join ## What changes were proposed in this pull request? After sort merge join for inner join, now we only keep left key ordering. However, after inner join, right key has the same value and order as left key. So if we need another smj on right key, we will unnecessarily add a sort which causes additional cost. As a more complicated example, A join B on A.key = B.key join C on B.key = C.key join D on A.key = D.key. We will unnecessarily add a sort on B.key when join {A, B} and C, and add a sort on A.key when join {A, B, C} and D. To fix this, we need to propagate all sorted information (equivalent expressions) from bottom up through `outputOrdering` and `SortOrder`. ## How was this patch tested? Test cases are added. Author: wangzhenhua Closes #17339 from wzhfy/sortEnhance. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +-- .../SubstituteUnresolvedOrdinals.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 4 +-- .../sql/catalyst/expressions/SortOrder.scala | 21 +++++++++++-- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 8 ++--- .../exchange/EnsureRequirements.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 26 ++++++++++++++-- .../spark/sql/execution/PlannerSuite.scala | 30 ++++++++++++++++++- 9 files changed, 81 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8cf407382619..574f91b09912 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -966,9 +966,9 @@ class Analyzer( case s @ Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) => + case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction, nullOrdering) + SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index af0a565f73ae..38a3d3de1288 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { - case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) => + case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 35ca2a0aa53a..75bf780d4142 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -109,9 +109,9 @@ package object dsl { def cast(to: DataType): Expression = Cast(expr, to) def asc: SortOrder = SortOrder(expr, Ascending) - def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast) + def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty) def desc: SortOrder = SortOrder(expr, Descending) - def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst) + def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3bebd552ef51..abcb9a2b939b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -53,8 +53,15 @@ case object NullsLast extends NullOrdering{ /** * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. + * `sameOrderExpressions` is a set of expressions with the same sort order as the child. It is + * derived from equivalence relation in an operator, e.g. left/right keys of an inner sort merge + * join. */ -case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering) +case class SortOrder( + child: Expression, + direction: SortDirection, + nullOrdering: NullOrdering, + sameOrderExpressions: Set[Expression]) extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. 
*/ @@ -75,11 +82,19 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql def isAscending: Boolean = direction == Ascending + + def satisfies(required: SortOrder): Boolean = { + (sameOrderExpressions + child).exists(required.child.semanticEquals) && + direction == required.direction && nullOrdering == required.nullOrdering + } } object SortOrder { - def apply(child: Expression, direction: SortDirection): SortOrder = { - new SortOrder(child, direction, direction.defaultNullOrdering) + def apply( + child: Expression, + direction: SortDirection, + sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = { + new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4c9fb2ec2774..cd238e05d410 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1229,7 +1229,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } else { direction.defaultNullOrdering } - SortOrder(expression(ctx.expression), direction, nullOrdering) + SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 38029552d13b..ae0703513cf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1037,7 +1037,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) } + def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) } /** * Returns a descending ordering used in sorting, where null values appear after non-null values. @@ -1052,7 +1052,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) } + def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) } /** * Returns an ascending ordering used in sorting. @@ -1082,7 +1082,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) } + def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) } /** * Returns an ordering used in sorting, where null values appear after non-null values. @@ -1097,7 +1097,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) } + def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } /** * Prints the expression to the console for debugging purpose. 
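As an illustration of the `satisfies` check introduced above (a sketch using internal Catalyst classes, not code from the patch; the literal keys are placeholders standing in for real join keys):

```scala
// Minimal sketch of SortOrder.satisfies with sameOrderExpressions.
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, Literal, SortOrder}

val keyA: Expression = Literal(1)   // e.g. the left join key
val keyB: Expression = Literal(2)   // e.g. the right key of an inner sort merge join

// Data sorted on keyA is, by the join's equivalence, also sorted on keyB.
val childOrdering = SortOrder(keyA, Ascending, sameOrderExpressions = Set(keyB))

// A downstream requirement to be sorted on keyB is therefore satisfied without an extra sort.
assert(childOrdering.satisfies(SortOrder(keyB, Ascending)))
```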
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f17049949aa4..b91d07744255 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -241,7 +241,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } else { requiredOrdering.zip(child.outputOrdering).forall { case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) + childOutputOrder.satisfies(requiredOrder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 02f4f55c7999..c6aae1a4db2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -81,17 +81,37 @@ case class SortMergeJoinExec( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { + // For inner join, orders of both sides keys should be kept. + case Inner => + val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => + // Also add the right key and its `sameOrderExpressions` + SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey + .sameOrderExpressions) + } // For left and right outer joins, the output is ordered by the streamed input's join keys. - case LeftOuter => requiredOrders(leftKeys) - case RightOuter => requiredOrders(rightKeys) + case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering) + case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering) // There are null rows in both streams, so there is no order. case FullOuter => Nil - case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys) + case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") } + /** + * For SMJ, child's output must have been sorted on key or expressions with the same order as + * key, so we can get ordering for key from child's output ordering. 
+ */ + private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder]) + : Seq[SortOrder] = { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key) + } + } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f2232fc489b7..4d155d538d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -477,14 +477,18 @@ class PlannerSuite extends SharedSQLContext { private val exprA = Literal(1) private val exprB = Literal(2) + private val exprC = Literal(3) private val orderingA = SortOrder(exprA, Ascending) private val orderingB = SortOrder(exprB, Ascending) + private val orderingC = SortOrder(exprC, Ascending) private val planA = DummySparkPlan(outputOrdering = Seq(orderingA), outputPartitioning = HashPartitioning(exprA :: Nil, 5)) private val planB = DummySparkPlan(outputOrdering = Seq(orderingB), outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + private val planC = DummySparkPlan(outputOrdering = Seq(orderingC), + outputPartitioning = HashPartitioning(exprC :: Nil, 5)) - assert(orderingA != orderingB) + assert(orderingA != orderingB && orderingA != orderingC && orderingB != orderingC) private def assertSortRequirementsAreSatisfied( childPlan: SparkPlan, @@ -508,6 +512,30 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) + // Both left and right keys should be sorted after the SMJ. + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = innerSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } + } + + test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + + "child SMJ") { + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC) + // After the second SMJ, exprA, exprB and exprC should all be sorted. + Seq(orderingA, orderingB, orderingC).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = parentSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } + } + test("EnsureRequirements for sort operator after left outer sort merge join") { // Only left key is sorted after left outer SMJ (thus doesn't need a sort). val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB) From 0ec1db5475f1a7839bdbf0d9cffe93ce6970a7fe Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 21 Mar 2017 11:17:34 +0800 Subject: [PATCH 082/512] [SPARK-19980][SQL] Add NULL checks in Bean serializer ## What changes were proposed in this pull request? A Bean serializer in `ExpressionEncoder` could change values when Beans having NULL. 
A concrete example is as follows; ``` scala> :paste class Outer extends Serializable { private var cls: Inner = _ def setCls(c: Inner): Unit = cls = c def getCls(): Inner = cls } class Inner extends Serializable { private var str: String = _ def setStr(s: String): Unit = str = str def getStr(): String = str } scala> Seq("""{"cls":null}""", """{"cls": {"str":null}}""").toDF().write.text("data") scala> val encoder = Encoders.bean(classOf[Outer]) scala> val schema = encoder.schema scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder) scala> df.show +------+ | cls| +------+ |[null]| | null| +------+ scala> df.map(x => x)(encoder).show() +------+ | cls| +------+ |[null]| |[null]| // <-- Value changed +------+ ``` This is because the Bean serializer does not have the NULL-check expressions that the serializer of Scala's product types has. Actually, this value change does not happen in Scala's product types; ``` scala> :paste case class Outer(cls: Inner) case class Inner(str: String) scala> val encoder = Encoders.product[Outer] scala> val schema = encoder.schema scala> val df = spark.read.schema(schema).json("data").as[Outer](encoder) scala> df.show +------+ | cls| +------+ |[null]| | null| +------+ scala> df.map(x => x)(encoder).show() +------+ | cls| +------+ |[null]| | null| +------+ ``` This pr added the NULL-check expressions in Bean serializer along with the serializer of Scala's product types. ## How was this patch tested? Added tests in `JavaDatasetSuite`. Author: Takeshi Yamamuro Closes #17347 from maropu/SPARK-19980. --- .../sql/catalyst/JavaTypeInference.scala | 11 +++++++++-- .../apache/spark/sql/JavaDatasetSuite.java | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 4ff87edde139..9d4617dda555 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -343,7 +343,11 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): CreateNamedStruct = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) + serializerFor(nullSafeInput, TypeToken.of(beanClass)) match { + case expressions.If(_, _, s: CreateNamedStruct) => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) + } } private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { @@ -427,7 +431,7 @@ object JavaTypeInference { case other => val properties = getJavaBeanReadableAndWritableProperties(other) - CreateNamedStruct(properties.flatMap { p => + val nonNullOutput = CreateNamedStruct(properties.flatMap { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val fieldValue = Invoke( @@ -436,6 +440,9 @@ object JavaTypeInference { inferExternalType(fieldType.getRawType)) expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil }) + + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) } } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ca9e5ad2ea86..ffb4c6273ff8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1380,4 +1380,23 @@ public void testCircularReferenceBean3() { CircularReference4Bean bean = new CircularReference4Bean(); spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class)); } + + @Test(expected = RuntimeException.class) + public void testNullInTopLevelBean() { + NestedSmallBean bean = new NestedSmallBean(); + // We cannot set null in top-level bean + spark.createDataset(Arrays.asList(bean, null), Encoders.bean(NestedSmallBean.class)); + } + + @Test + public void testSerializeNull() { + NestedSmallBean bean = new NestedSmallBean(); + Encoder encoder = Encoders.bean(NestedSmallBean.class); + List beans = Arrays.asList(bean); + Dataset ds1 = spark.createDataset(beans, encoder); + Assert.assertEquals(beans, ds1.collectAsList()); + Dataset ds2 = + ds1.map((MapFunction) b -> b, encoder); + Assert.assertEquals(beans, ds2.collectAsList()); + } } From 7fa116f8fc77906202217c0cd2f9718a4e62632b Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Tue, 21 Mar 2017 11:51:22 +0800 Subject: [PATCH 083/512] [SPARK-17204][CORE] Fix replicated off heap storage (Jira: https://issues.apache.org/jira/browse/SPARK-17204) ## What changes were proposed in this pull request? There are a couple of bugs in the `BlockManager` with respect to support for replicated off-heap storage. First, the locally-stored off-heap byte buffer is disposed of when it is replicated. It should not be. Second, the replica byte buffers are stored as heap byte buffers instead of direct byte buffers even when the storage level memory mode is off-heap. This PR addresses both of these problems. ## How was this patch tested? `BlockManagerReplicationSuite` was enhanced to fill in the coverage gaps. It now fails if either of the bugs in this PR exist. Author: Michael Allman Closes #16499 from mallman/spark-17204-replicated_off_heap_storage. --- .../apache/spark/storage/BlockManager.scala | 23 ++++++-- .../apache/spark/storage/StorageUtils.scala | 52 ++++++++++++++++--- .../spark/util/ByteBufferInputStream.scala | 8 +-- .../spark/util/io/ChunkedByteBuffer.scala | 27 ++++++++-- .../BlockManagerReplicationSuite.scala | 20 +++++-- 5 files changed, 105 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 45b73380806d..245d94ac4f8b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -317,6 +317,9 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. + * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. */ override def putBlockData( blockId: BlockId, @@ -755,6 +758,9 @@ private[spark] class BlockManager( /** * Put a new block of serialized bytes to the block manager. * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @param encrypt If true, asks the block manager to encrypt the data block before storing, * when I/O encryption is enabled. 
This is required for blocks that have been * read from unencrypted sources, since all the BlockManager read APIs @@ -773,7 +779,7 @@ private[spark] class BlockManager( if (encrypt && securityManager.ioEncryptionKey.isDefined) { try { val data = bytes.toByteBuffer - val in = new ByteBufferInputStream(data, true) + val in = new ByteBufferInputStream(data) val byteBufOut = new ByteBufferOutputStream(data.remaining()) val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf, securityManager.ioEncryptionKey.get) @@ -800,6 +806,9 @@ private[spark] class BlockManager( * * If the block already exists, this method will not overwrite it. * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @param keepReadLock if true, this method will hold the read lock when it returns (even if the * block already exists). If false, this method will hold no locks when it * returns. @@ -843,7 +852,15 @@ private[spark] class BlockManager( false } } else { - memoryStore.putBytes(blockId, size, level.memoryMode, () => bytes) + val memoryMode = level.memoryMode + memoryStore.putBytes(blockId, size, memoryMode, () => { + if (memoryMode == MemoryMode.OFF_HEAP && + bytes.chunks.exists(buffer => !buffer.isDirect)) { + bytes.copy(Platform.allocateDirectBuffer) + } else { + bytes + } + }) } if (!putSucceeded && level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") @@ -1048,7 +1065,7 @@ private[spark] class BlockManager( try { replicate(blockId, bytesToReplicate, level, remoteClassTag) } finally { - bytesToReplicate.dispose() + bytesToReplicate.unmap() } logDebug("Put block %s remotely took %s" .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e12f2e6095d5..5efdd23f79a2 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -236,22 +236,60 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { + // Ewwww... Reflection!!! See the unmap method for justification + private val memoryMappedBufferFileDescriptorField = { + val mappedBufferClass = classOf[java.nio.MappedByteBuffer] + val fdField = mappedBufferClass.getDeclaredField("fd") + fdField.setAccessible(true) + fdField + } /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun + * API that will cause errors if one attempts to read from the disposed buffer. However, neither + * the bytes allocated to direct buffers nor file descriptors opened for memory-mapped buffers put + * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of + * off-heap memory or huge numbers of open files. There's unfortunately no standard API to + * manually dispose of these kinds of buffers. 
+ * + * See also [[unmap]] */ def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace(s"Unmapping $buffer") - if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { - buffer.asInstanceOf[DirectBuffer].cleaner().clean() + logTrace(s"Disposing of $buffer") + cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) + } + } + + /** + * Attempt to unmap a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that will + * cause errors if one attempts to read from the unmapped buffer. However, the file descriptors of + * memory-mapped buffers do not put pressure on the garbage collector. Waiting for garbage + * collection may lead to huge numbers of open files. There's unfortunately no standard API to + * manually unmap memory-mapped buffers. + * + * See also [[dispose]] + */ + def unmap(buffer: ByteBuffer): Unit = { + if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { + // Note that direct buffers are instances of MappedByteBuffer. As things stand in Java 8, the + // JDK does not provide a public API to distinguish between direct buffers and memory-mapped + // buffers. As an alternative, we peek beneath the curtains and look for a non-null file + // descriptor in mappedByteBuffer + if (memoryMappedBufferFileDescriptorField.get(buffer) != null) { + logTrace(s"Unmapping $buffer") + cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) } } } + private def cleanDirectBuffer(buffer: DirectBuffer) = { + val cleaner = buffer.cleaner() + if (cleaner != null) { + cleaner.clean() + } + } + /** * Update the given list of RDDInfo with the given list of storage statuses. * This method overwrites the old values stored in the RDDInfo's. diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala index dce2ac63a664..50dc948e6c41 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -23,11 +23,10 @@ import java.nio.ByteBuffer import org.apache.spark.storage.StorageUtils /** - * Reads data from a ByteBuffer, and optionally cleans it up using StorageUtils.dispose() - * at the end of the stream (e.g. to close a memory-mapped file). + * Reads data from a ByteBuffer. */ private[spark] -class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) +class ByteBufferInputStream(private var buffer: ByteBuffer) extends InputStream { override def read(): Int = { @@ -72,9 +71,6 @@ class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = f */ private def cleanUp() { if (buffer != null) { - if (dispose) { - StorageUtils.dispose(buffer) - } buffer = null } } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7572cac39317..1667516663b3 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -86,7 +86,11 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Copy this buffer into a new ByteBuffer. + * Convert this buffer to a ByteBuffer. If this buffer is backed by a single chunk, its underlying + * data will not be copied. Instead, it will be duplicated. 
If this buffer is backed by multiple + * chunks, the data underlying this buffer will be copied into a new byte buffer. As a result, it + * is suggested to use this method only if the caller does not need to manage the memory + * underlying this buffer. * * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size. */ @@ -132,10 +136,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped. + * See [[StorageUtils.dispose]] for more information. + * + * See also [[unmap]] */ def dispose(): Unit = { if (!disposed) { @@ -143,6 +147,19 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { disposed = true } } + + /** + * Attempt to unmap any ByteBuffer in this ChunkedByteBuffer if it is memory-mapped. See + * [[StorageUtils.unmap]] for more information. + * + * See also [[dispose]] + */ + def unmap(): Unit = { + if (!disposed) { + chunks.foreach(StorageUtils.unmap) + disposed = true + } + } } /** diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 75dc04038deb..d907add920c8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -374,7 +374,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite // Put the block into one of the stores val blockId = new TestBlockId( "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) - stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel) + val testValue = Array.fill[Byte](blockSize)(1) + stores(0).putSingle(blockId, testValue, storageLevel) // Assert that master know two locations for the block val blockLocations = master.getLocations(blockId).map(_.executorId).toSet @@ -386,12 +387,23 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite testStore => blockLocations.contains(testStore.blockManagerId.executorId) }.foreach { testStore => val testStoreName = testStore.blockManagerId.executorId - assert( - testStore.getLocalValues(blockId).isDefined, s"$blockId was not found in $testStoreName") - testStore.releaseLock(blockId) + val blockResultOpt = testStore.getLocalValues(blockId) + assert(blockResultOpt.isDefined, s"$blockId was not found in $testStoreName") + val localValues = blockResultOpt.get.data.toSeq + assert(localValues.size == 1) + assert(localValues.head === testValue) assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), s"master does not have status for ${blockId.name} in $testStoreName") + val memoryStore = testStore.memoryStore + if (memoryStore.contains(blockId) && !storageLevel.deserialized) { + memoryStore.getBytes(blockId).get.chunks.foreach { byteBuffer => + assert(storageLevel.useOffHeap == byteBuffer.isDirect, + s"memory mode ${storageLevel.memoryMode} is not compatible with " + + byteBuffer.getClass.getSimpleName) + } + } + val blockStatus = master.getBlockStatus(blockId)(testStore.blockManagerId) // 
Assert that block status in the master for this store has expected storage level From 21e366aea5a7f49e42e78dce06ff6b3ee1e36f06 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 21 Mar 2017 12:17:26 +0800 Subject: [PATCH 084/512] [SPARK-19912][SQL] String literals should be escaped for Hive metastore partition pruning ## What changes were proposed in this pull request? Because the current `HiveShim`'s `convertFilters` does not escape string literals, the following correctness issues exist. This PR aims to return the correct result and to show a clearer exception message. **BEFORE** ```scala scala> Seq((1, "p1", "q1"), (2, "p1\" and q=\"q1", "q2")).toDF("a", "p", "q").write.partitionBy("p", "q").saveAsTable("t1") scala> spark.table("t1").filter($"p" === "p1\" and q=\"q1").select($"a").show +---+ | a| +---+ +---+ scala> spark.table("t1").filter($"p" === "'\"").select($"a").show java.lang.RuntimeException: Caught Hive MetaException attempting to get partition metadata by filter from ... ``` **AFTER** ```scala scala> spark.table("t1").filter($"p" === "p1\" and q=\"q1").select($"a").show +---+ | a| +---+ | 2| +---+ scala> spark.table("t1").filter($"p" === "'\"").select($"a").show java.lang.UnsupportedOperationException: Partition filter cannot have both `"` and `'` characters ``` ## How was this patch tested? Passed the Jenkins tests with new test cases. Author: Dongjoon Hyun Closes #17266 from dongjoon-hyun/SPARK-19912. --- .../apache/spark/sql/hive/client/HiveShim.scala | 16 ++++++++++++++-- .../spark/sql/hive/client/FiltersSuite.scala | 5 +++++ .../spark/sql/hive/execution/SQLQuerySuite.scala | 16 ++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 76568f599078..d55c41e5c9f2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -596,13 +596,24 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { s"$v ${op.symbol} ${a.name}" case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} "$v"""" + s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) if !varcharKeys.contains(a.name) => - s""""$v" ${op.symbol} ${a.name}""" + s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" }.mkString(" and ") } + private def quoteStringLiteral(str: String): String = { + if (!str.contains("\"")) { + s""""$str"""" + } else if (!str.contains("'")) { + s"""'$str'""" + } else { + throw new UnsupportedOperationException( + """Partition filter cannot have both `"` and `'` characters""") + } + } + override def getPartitionsByFilter( hive: Hive, table: Table, @@ -611,6 +622,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { // Hive getPartitionsByFilter() takes a string that represents partition // predicates like "str_key=\"value\" and int_key=1 ..." 
val filter = convertFilters(table, predicates) + val partitions = if (filter.isEmpty) { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index cd96c85f3e20..031c1a5ec0ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -65,6 +65,11 @@ class FiltersSuite extends SparkFunSuite with Logging { (Literal("") === a("varchar", StringType)) :: Nil, "") + filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning", + (a("stringcol", StringType) === Literal("p1\" and q=\"q1")) :: + (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil, + """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") + private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { val converted = shim.convertFilters(testTable, filters) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 236135dcff52..55ff4bb115e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2057,4 +2057,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") { + withTable("spark_19912") { + Seq( + (1, "p1", "q1"), + (2, "'", "q2"), + (3, "\"", "q3"), + (4, "p1\" and q=\"q1", "q4") + ).toDF("a", "p", "q").write.partitionBy("p", "q").saveAsTable("spark_19912") + + val table = spark.table("spark_19912") + checkAnswer(table.filter($"p" === "'").select($"a"), Row(2)) + checkAnswer(table.filter($"p" === "\"").select($"a"), Row(3)) + checkAnswer(table.filter($"p" === "p1\" and q=\"q1").select($"a"), Row(4)) + } + } } From 68d65fae71e475ad811a9716098aca03a2af9532 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 20 Mar 2017 21:43:14 -0700 Subject: [PATCH 085/512] [SPARK-19949][SQL] unify bad record handling in CSV and JSON ## What changes were proposed in this pull request? Currently JSON and CSV have exactly the same logic about handling bad records, this PR tries to abstract it and put it in a upper level to reduce code duplication. The overall idea is, we make the JSON and CSV parser to throw a BadRecordException, then the upper level, FailureSafeParser, handles bad records according to the parse mode. Behavior changes: 1. with PERMISSIVE mode, if the number of tokens doesn't match the schema, previously CSV parser will treat it as a legal record and parse as many tokens as possible. After this PR, we treat it as an illegal record, and put the raw record string in a special column, but we still parse as many tokens as possible. 2. all logging is removed as they are not very useful in practice. ## How was this patch tested? existing tests Author: Wenchen Fan Author: hyukjinkwon Author: Wenchen Fan Closes #17315 from cloud-fan/bad-record2. 
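For context, here is a rough, user-facing sketch of how the three parse modes are expected to behave for CSV after this change. The file name, schema, and corrupt-column name below are illustrative only and are not part of this patch:

```scala
// Assume bad.csv contains:
//   1,2
//   3,4,5      <- wrong number of tokens: now treated as a malformed record
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("a", IntegerType)
  .add("b", IntegerType)
  .add("_corrupt_record", StringType)  // receives the raw text of malformed rows

// PERMISSIVE (default): keep the row, fill the parsed fields best-effort,
// and put the raw record string into the corrupt-record column.
spark.read.schema(schema)
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt_record")
  .csv("bad.csv")
  .show()

// DROPMALFORMED: silently drop the malformed row
// (the per-record warning logging has been removed).
spark.read.schema(schema).option("mode", "DROPMALFORMED").csv("bad.csv").show()

// FAILFAST: rethrow the cause of the BadRecordException and fail the query.
spark.read.schema(schema).option("mode", "FAILFAST").csv("bad.csv").show()
```

Because the same `FailureSafeParser` now drives both paths, the JSON reader behaves analogously under these three modes.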
--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 5 +- .../expressions/jsonExpressions.scala | 4 +- .../spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../sql/catalyst/json/JacksonParser.scala | 122 +---------- .../sql/catalyst/util/FailureSafeParser.scala | 80 +++++++ .../apache/spark/sql/DataFrameReader.scala | 23 +- .../datasources/csv/CSVDataSource.scala | 17 +- .../datasources/csv/CSVFileFormat.scala | 7 +- .../datasources/csv/CSVOptions.scala | 2 +- .../datasources/csv/UnivocityParser.scala | 197 ++++++------------ .../datasources/json/JsonDataSource.scala | 31 ++- .../datasources/json/JsonFileFormat.scala | 7 +- .../execution/datasources/csv/CSVSuite.scala | 2 +- .../datasources/json/JsonSuite.scala | 8 +- 14 files changed, 222 insertions(+), 285 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index cbc3569795d9..394d1a04e09c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1370,9 +1370,8 @@ test_that("column functions", { # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) schema2 <- structType(structField("date", "date")) - expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))), - error = function(e) { stop(e) }), - paste0(".*(java.lang.NumberFormatException: For input string:).*")) + s <- collect(select(df, from_json(df$col, schema2))) + expect_equal(s[[1]][[1]], NA) s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy"))) expect_is(s[[1]][[1]]$date, "Date") expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index e4e08a8665a5..08af5522d822 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -583,7 +583,7 @@ case class JsonToStructs( CreateJacksonParser.utf8String, identity[UTF8String])) } catch { - case _: SparkSQLJsonProcessingException => null + case _: BadRecordException => null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f222ec602c9..355c26afa6f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -65,7 +65,7 @@ private[sql] class JSONOptions( val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val 
compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 9b80c0fc87c9..fdb7d88d5bd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -32,17 +32,14 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) - /** * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( schema: StructType, - options: JSONOptions) extends Logging { + val options: JSONOptions) extends Logging { import JacksonUtils._ - import ParseModes._ import com.fasterxml.jackson.core.JsonToken._ // A `ValueConverter` is responsible for converting a value from `JsonParser` @@ -55,108 +52,6 @@ class JacksonParser( private val factory = new JsonFactory() options.setJacksonOptions(factory) - private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) - - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - @transient - private[this] var isWarningPrinted: Boolean = false - - @transient - private def printWarningForMalformedRecord(record: () => UTF8String): Unit = { - def sampleRecord: String = { - if (options.wholeFile) { - "" - } else { - s"Sample record: ${record()}\n" - } - } - - def footer: String = { - s"""Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}. - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin - } - - if (options.permissive) { - logWarning( - s"""Found at least one malformed record. The JSON reader will replace - |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. - |To find out which corrupted records have been replaced with null, please use the - |default inferred schema instead of providing a custom schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } else if (options.dropMalformed) { - logWarning( - s"""Found at least one malformed record. The JSON reader will drop - |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which - |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE - |mode and use the default inferred schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } - } - - @transient - private def printWarningIfWholeFile(): Unit = { - if (options.wholeFile && corruptFieldIndex.isDefined) { - logWarning( - s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result - |in very large allocations or OutOfMemoryExceptions being raised. 
- | - """.stripMargin) - } - } - - /** - * This function deals with the cases it fails to parse. This function will be called - * when exceptions are caught during converting. This functions also deals with `mode` option. - */ - private def failedRecord(record: () => UTF8String): Seq[InternalRow] = { - corruptFieldIndex match { - case _ if options.failFast => - if (options.wholeFile) { - throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode") - } else { - throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}") - } - - case _ if options.dropMalformed => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - Nil - - case None => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - emptyRow - - case Some(corruptIndex) => - if (!isWarningPrinted) { - printWarningIfWholeFile() - isWarningPrinted = true - } - val row = new GenericInternalRow(schema.length) - row.update(corruptIndex, record()) - Seq(row) - } - } - /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. This is a wrapper for the method @@ -239,7 +134,7 @@ class JacksonParser( lowerCaseValue.equals("-inf")) { value.toFloat } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") + throw new RuntimeException(s"Cannot parse $value as FloatType.") } } @@ -259,7 +154,7 @@ class JacksonParser( lowerCaseValue.equals("-inf")) { value.toDouble } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") + throw new RuntimeException(s"Cannot parse $value as DoubleType.") } } @@ -391,9 +286,8 @@ class JacksonParser( case token => // We cannot parse this token based on the given data type. So, we throw a - // SparkSQLJsonProcessingException and this exception will be caught by - // `parse` method. - throw new SparkSQLJsonProcessingException( + // RuntimeException and this exception will be caught by `parse` method. + throw new RuntimeException( s"Failed to parse a value for data type $dataType (current token: $token).") } @@ -466,14 +360,14 @@ class JacksonParser( parser.nextToken() match { case null => Nil case _ => rootConverter.apply(parser) match { - case null => throw new SparkSQLJsonProcessingException("Root converter returned null") + case null => throw new RuntimeException("Root converter returned null") case rows => rows } } } } catch { - case _: JsonProcessingException | _: SparkSQLJsonProcessingException => - failedRecord(() => recordLiteral(record)) + case e @ (_: RuntimeException | _: JsonProcessingException) => + throw BadRecordException(() => recordLiteral(record), () => None, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala new file mode 100644 index 000000000000..e8da10d65ecb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +class FailureSafeParser[IN]( + rawParser: IN => Seq[InternalRow], + mode: String, + schema: StructType, + columnNameOfCorruptRecord: String) { + + private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) + private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) + private val resultRow = new GenericInternalRow(schema.length) + private val nullResult = new GenericInternalRow(schema.length) + + // This function takes 2 parameters: an optional partial result, and the bad record. If the given + // schema doesn't contain a field for corrupted record, we just return the partial result or a + // row with all fields null. If the given schema contains a field for corrupted record, we will + // set the bad record to this field, and set other fields according to the partial result or null. + private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { + if (corruptFieldIndex.isDefined) { + (row, badRecord) => { + var i = 0 + while (i < actualSchema.length) { + val from = actualSchema(i) + resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull + i += 1 + } + resultRow(corruptFieldIndex.get) = badRecord() + resultRow + } + } else { + (row, _) => row.getOrElse(nullResult) + } + } + + def parse(input: IN): Iterator[InternalRow] = { + try { + rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + } catch { + case e: BadRecordException if ParseModes.isPermissiveMode(mode) => + Iterator(toResultRow(e.partialResult(), e.record)) + case _: BadRecordException if ParseModes.isDropMalformedMode(mode) => + Iterator.empty + case e: BadRecordException => throw e.cause + } + } +} + +/** + * Exception thrown when the underlying parser meet a bad record and can't parse it. + * @param record a function to return the record that cause the parser to fail + * @param partialResult a function that returns an optional row, which is the partial result of + * parsing this bad record. + * @param cause the actual exception about why the record is bad and can't be parsed. 
+ */ +case class BadRecordException( + record: () => UTF8String, + partialResult: () => Option[InternalRow], + cause: Throwable) extends Exception(cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 88fbfb4c92a0..767a636d7073 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -27,6 +27,7 @@ import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.csv._ @@ -382,11 +383,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val actualSchema = + StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val createParser = CreateJacksonParser.string _ val parsed = jsonDataset.rdd.mapPartitions { iter => - val parser = new JacksonParser(schema, parsedOptions) - iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) + val rawParser = new JacksonParser(actualSchema, parsedOptions) + val parser = new FailureSafeParser[String]( + input => rawParser.parse(input, createParser, UTF8String.fromString), + parsedOptions.parseMode, + schema, + parsedOptions.columnNameOfCorruptRecord) + iter.flatMap(parser.parse) } Dataset.ofRows( @@ -435,14 +443,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val actualSchema = + StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) val parsed = linesWithoutHeader.mapPartitions { iter => - val parser = new UnivocityParser(schema, parsedOptions) - iter.flatMap(line => parser.parse(line)) + val rawParser = new UnivocityParser(actualSchema, parsedOptions) + val parser = new FailureSafeParser[String]( + input => Seq(rawParser.parse(input)), + parsedOptions.parseMode, + schema, + parsedOptions.columnNameOfCorruptRecord) + iter.flatMap(parser.parse) } Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 35ff924f27ce..63af18ec5b8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -49,7 +49,7 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] + schema: StructType): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. 
@@ -115,17 +115,17 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] = { + schema: StructType): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) linesReader.map { line => - new String(line.getBytes, 0, line.getLength, parsedOptions.charset) + new String(line.getBytes, 0, line.getLength, parser.options.charset) } } - val shouldDropHeader = parsedOptions.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser) + val shouldDropHeader = parser.options.headerFlag && file.start == 0 + UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) } override def infer( @@ -192,11 +192,12 @@ object WholeFileCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] = { + schema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), - parsedOptions.headerFlag, - parser) + parser.options.headerFlag, + parser, + schema) } override def infer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 29c41455279e..eef43c7629c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -113,8 +113,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value - val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions) + val parser = new UnivocityParser( + StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + parsedOptions) + CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 2632e87971d6..f6c6b6f56cd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -82,7 +82,7 @@ class CSVOptions( val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val charset = parameters.getOrElse("encoding", parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index e42ea3fa391f..263f77e11c4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -30,14 +30,14 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( schema: StructType, requiredSchema: StructType, - private val options: CSVOptions) extends Logging { + val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") @@ -46,39 +46,26 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) - private val tokenizer = new CsvParser(options.asParserSettings) - private var numMalformedRecords = 0 - private val row = new GenericInternalRow(requiredSchema.length) - // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field - // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method. - private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd + // Retrieve the raw record string. + private def getCurrentInput: UTF8String = { + UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd) + } - // This parser loads an `tokenIndexArr`-th position value in input tokens, - // then put the value in `row(rowIndexArr)`. + // This parser first picks some tokens from the input tokens, according to the required schema, + // then parse these tokens and put the values in a row, with the order specified by the required + // schema. // // For example, let's say there is CSV data as below: // // a,b,c // 1,2,A // - // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true` - // by user and the user selects "c", "b", "_unparsed" and "a" fields. In this case, we need - // to map those values below: - // - // required schema - ["c", "b", "_unparsed", "a"] - // CSV data schema - ["a", "b", "c"] - // required CSV data schema - ["c", "b", "a"] + // So the CSV data schema is: ["a", "b", "c"] + // And let's say the required schema is: ["c", "b"] // // with the input tokens, // @@ -86,45 +73,12 @@ class UnivocityParser( // // Each input token is placed in each output row's position by mapping these. In this case, // - // output row - ["A", 2, null, 1] - // - // In more details, - // - `valueConverters`, input tokens - CSV data schema - // `valueConverters` keeps the positions of input token indices (by its index) to each - // value's converter (by its value) in an order of CSV data schema. In this case, - // [string->int, string->int, string->string]. 
- // - // - `tokenIndexArr`, input tokens - required CSV data schema - // `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered - // fields given the required CSV data schema (by its value). In this case, [2, 1, 0]. - // - // - `rowIndexArr`, input tokens - required schema - // `rowIndexArr` keeps the positions of input token indices (by its index) to reordered - // field indices given the required schema (by its value). In this case, [0, 1, 3]. + // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = - dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means - // the fields that we should try to convert. - private val reorderedFields = if (options.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) - } else { - requiredSchema - } + schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray private val tokenIndexArr: Array[Int] = { - reorderedFields - .filter(_.name != options.columnNameOfCorruptRecord) - .map(f => dataSchema.indexOf(f)).toArray - } - - private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) { - val corrFieldIndex = corruptFieldIndex.get - reorderedFields.indices.filter(_ != corrFieldIndex).toArray - } else { - reorderedFields.indices.toArray + requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -205,7 +159,7 @@ class UnivocityParser( } case _: StringType => (d: String) => - nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) + nullSafeDatum(d, name, nullable, options)(UTF8String.fromString) case udt: UserDefinedType[_] => (datum: String) => makeConverter(name, udt.sqlType, nullable, options) @@ -233,81 +187,41 @@ class UnivocityParser( * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) - - private def convert(tokens: Array[String]): Option[InternalRow] = { - convertWithParseMode(tokens) { tokens => - var i: Int = 0 - while (i < tokenIndexArr.length) { - // It anyway needs to try to parse since it decides if this row is malformed - // or not after trying to cast in `DROPMALFORMED` mode even if the casted - // value is not stored in the row. - val from = tokenIndexArr(i) - val to = rowIndexArr(i) - val value = valueConverters(from).apply(tokens(from)) - if (i < requiredSchema.length) { - row(to) = value - } - i += 1 - } - row - } - } - - private def convertWithParseMode( - tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { - if (options.dropMalformed && dataSchema.length != tokens.length) { - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. 
Malformed records from now on will not be logged.") + def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + + private def convert(tokens: Array[String]): InternalRow = { + if (tokens.length != schema.length) { + // If the number of tokens doesn't match the schema, we should treat it as a malformed record. + // However, we still have chance to parse some of the tokens, by adding extra null tokens in + // the tail if the number is smaller, or by dropping extra tokens if the number is larger. + val checkedTokens = if (schema.length > tokens.length) { + tokens ++ new Array[String](schema.length - tokens.length) + } else { + tokens.take(schema.length) } - numMalformedRecords += 1 - None - } else if (options.failFast && dataSchema.length != tokens.length) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: " + - s"${tokens.mkString(options.delimiter.toString)}") - } else { - // If a length of parsed tokens is not equal to expected one, it makes the length the same - // with the expected. If the length is shorter, it adds extra tokens in the tail. - // If longer, it drops extra tokens. - // - // TODO: Revisit this; if a length of tokens does not match an expected length in the schema, - // we probably need to treat it as a malformed record. - // See an URL below for related discussions: - // https://github.com/apache/spark/pull/16928#discussion_r102657214 - val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) { - if (dataSchema.length > tokens.length) { - tokens ++ new Array[String](dataSchema.length - tokens.length) - } else { - tokens.take(dataSchema.length) + def getPartialResult(): Option[InternalRow] = { + try { + Some(convert(checkedTokens)) + } catch { + case _: BadRecordException => None } - } else { - tokens } - + throw BadRecordException( + () => getCurrentInput, + getPartialResult, + new RuntimeException("Malformed CSV record")) + } else { try { - Some(convert(checkedTokens)) + var i = 0 + while (i < requiredSchema.length) { + val from = tokenIndexArr(i) + row(i) = valueConverters(from).apply(tokens(from)) + i += 1 + } + row } catch { - case NonFatal(e) if options.permissive => - val row = new GenericInternalRow(requiredSchema.length) - corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) - Some(row) - case NonFatal(e) if options.dropMalformed => - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. 
Malformed records from now on will not be logged.") - } - numMalformedRecords += 1 - None + case NonFatal(e) => + throw BadRecordException(() => getCurrentInput, () => None, e) } } } @@ -331,10 +245,16 @@ private[csv] object UnivocityParser { def parseStream( inputStream: InputStream, shouldDropHeader: Boolean, - parser: UnivocityParser): Iterator[InternalRow] = { + parser: UnivocityParser, + schema: StructType): Iterator[InternalRow] = { val tokenizer = parser.tokenizer + val safeParser = new FailureSafeParser[Array[String]]( + input => Seq(parser.convert(input)), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => - parser.convert(tokens) + safeParser.parse(tokens) }.flatten } @@ -368,7 +288,8 @@ private[csv] object UnivocityParser { def parseIterator( lines: Iterator[String], shouldDropHeader: Boolean, - parser: UnivocityParser): Iterator[InternalRow] = { + parser: UnivocityParser, + schema: StructType): Iterator[InternalRow] = { val options = parser.options val linesWithoutHeader = if (shouldDropHeader) { @@ -381,6 +302,12 @@ private[csv] object UnivocityParser { val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) - filteredLines.flatMap(line => parser.parse(line)) + + val safeParser = new FailureSafeParser[String]( + input => Seq(parser.parse(input)), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + filteredLines.flatMap(safeParser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 84f026620d90..51e952c12202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.io.InputStream + import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration @@ -31,6 +33,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile} import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -49,7 +52,8 @@ abstract class JsonDataSource extends Serializable { def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] final def inferSchema( sparkSession: SparkSession, @@ -127,10 +131,16 @@ object TextInputJsonDataSource extends JsonDataSource { override def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] = { + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - 
linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String)) + val safeParser = new FailureSafeParser[Text]( + input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + linesReader.flatMap(safeParser.parse) } private def textToUTF8String(value: Text): UTF8String = { @@ -180,7 +190,8 @@ object WholeFileJsonDataSource extends JsonDataSource { override def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] = { + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { Utils.tryWithResource { CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) @@ -189,9 +200,13 @@ object WholeFileJsonDataSource extends JsonDataSource { } } - parser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), - CreateJacksonParser.inputStream, - partitionedFileString).toIterator + val safeParser = new FailureSafeParser[InputStream]( + input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + + safeParser.parse( + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index a9dd91eba6f7..53d62d88b04c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -102,6 +102,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val actualSchema = + StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) // Check a field requirement for corrupt records here to throw an exception in a driver side dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => val f = dataSchema(corruptFieldIndex) @@ -112,11 +114,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } (file: PartitionedFile) => { - val parser = new JacksonParser(requiredSchema, parsedOptions) + val parser = new JacksonParser(actualSchema, parsedOptions) JsonDataSource(parsedOptions).readFile( broadcastedHadoopConf.value.value, file, - parser) + parser, + requiredSchema) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 95dfdf5b298e..598babfe0e7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -293,7 +293,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)).collect() } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed CSV record")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9b0efcbdaf5c..56fcf773f7dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(corruptRecords) .collect() } - assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) + assert(exceptionOne.getMessage.contains("JsonParseException")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1052,7 +1052,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(corruptRecords) .collect() } - assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode: {")) + assert(exceptionTwo.getMessage.contains("JsonParseException")) } test("Corrupt records: DROPMALFORMED mode") { @@ -1929,7 +1929,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exceptionOne.getMessage.contains("Failed to parse a value")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1939,7 +1939,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exceptionTwo.getMessage.contains("Failed to parse a value")) } } From d2dcd6792f4cea39e12945ad8c4cda5d8d034de4 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 20 Mar 2017 22:52:45 -0700 Subject: [PATCH 086/512] [SPARK-20024][SQL][TEST-MAVEN] SessionCatalog reset need to set the current database of ExternalCatalog ### What changes were proposed in this pull request? SessionCatalog API setCurrentDatabase does not set the current database of the underlying ExternalCatalog. Thus, weird errors could come in the test suites after we call reset. We need to fix it. So far, have not found the direct impact in the other code paths because we expect all the SessionCatalog APIs should always use the current database value we managed, unless some of code paths skip it. Thus, we fix it in the test-only function reset(). ### How was this patch tested? Multiple test case failures are observed in mvn and add a test case in SessionCatalogSuite. Author: Xiao Li Closes #17354 from gatorsmile/useDB. 
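The drift described above is easy to see with two stub catalog layers. The sketch below uses hypothetical classes (`ExternalCatalogStub`, `SessionCatalogStub`, `ResetDemo`), not the real catalog API; it only illustrates why restoring the session-level current database alone is not enough and what the one-line addition to `reset()` in the diff below restores.

```scala
// A minimal sketch, assuming nothing beyond plain Scala: two catalog layers that
// each keep their own notion of the current database, as described above.
class ExternalCatalogStub {
  var currentDb: String = "default"
  def setCurrentDatabase(db: String): Unit = currentDb = db
}

class SessionCatalogStub(val external: ExternalCatalogStub) {
  var currentDb: String = "default"
  // Mirrors the problem: the session-level setter does not touch the external catalog.
  def setCurrentDatabase(db: String): Unit = currentDb = db
  // Before this patch: only the session-level state is restored.
  def resetOld(): Unit = setCurrentDatabase("default")
  // After this patch: both layers are restored, like reset() in the diff below.
  def resetNew(): Unit = {
    setCurrentDatabase("default")
    external.setCurrentDatabase("default")
  }
}

object ResetDemo extends App {
  val ext = new ExternalCatalogStub
  val cat = new SessionCatalogStub(ext)
  // Some earlier code path switched both layers to another database.
  cat.setCurrentDatabase("testing")
  ext.setCurrentDatabase("testing")
  cat.resetOld()
  println(s"old reset -> session=${cat.currentDb}, external=${ext.currentDb}") // default vs. testing
  cat.resetNew()
  println(s"new reset -> session=${cat.currentDb}, external=${ext.currentDb}") // both default
}
```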
--- .../org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 1 + .../apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 25aa8d3ba921..b134fd44a311 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1175,6 +1175,7 @@ class SessionCatalog( */ def reset(): Unit = synchronized { setCurrentDatabase(DEFAULT_DATABASE) + externalCatalog.setCurrentDatabase(DEFAULT_DATABASE) listDatabases().filter(_ != DEFAULT_DATABASE).foreach { db => dropDatabase(db, ignoreIfNotExists = false, cascade = true) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index bb87763e0bbb..fd9e5d6bb13e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -53,7 +53,6 @@ abstract class SessionCatalogSuite extends PlanTest { private def withBasicCatalog(f: SessionCatalog => Unit): Unit = { val catalog = new SessionCatalog(newBasicCatalog()) - catalog.createDatabase(newDb("default"), ignoreIfExists = true) try { f(catalog) } finally { @@ -76,7 +75,6 @@ abstract class SessionCatalogSuite extends PlanTest { test("basic create and list databases") { withEmptyCatalog { catalog => - catalog.createDatabase(newDb("default"), ignoreIfExists = true) assert(catalog.databaseExists("default")) assert(!catalog.databaseExists("testing")) assert(!catalog.databaseExists("testing2")) From 7620aed828d8baefc425b54684a83c81f1507b02 Mon Sep 17 00:00:00 2001 From: christopher snow Date: Tue, 21 Mar 2017 13:23:59 +0000 Subject: [PATCH 087/512] [SPARK-20011][ML][DOCS] Clarify documentation for ALS 'rank' parameter ## What changes were proposed in this pull request? API documentation and collaborative filtering documentation page changes to clarify inconsistent description of ALS rank parameter. - [DOCS] was previously: "rank is the number of latent factors in the model." - [API] was previously: "rank - number of features to use" This change describes rank in both places consistently as: - "Number of features to use (also referred to as the number of latent factors)" Author: Chris Snow Author: christopher snow Closes #17345 from snowch/SPARK-20011. --- docs/mllib-collaborative-filtering.md | 2 +- .../apache/spark/mllib/recommendation/ALS.scala | 16 ++++++++-------- python/pyspark/mllib/recommendation.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 0f891a09a6e6..d1bb6d69f125 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -20,7 +20,7 @@ algorithm to learn these latent factors. The implementation in `spark.mllib` has following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). -* *rank* is the number of latent factors in the model. +* *rank* is the number of features to use (also referred to as the number of latent factors). 
* *iterations* is the number of iterations of ALS to run. ALS typically converges to a reasonable solution in 20 iterations or less. * *lambda* specifies the regularization parameter in ALS. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 76b1bc13b4b0..14288221b694 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -301,7 +301,7 @@ object ALS { * level of parallelism. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -326,7 +326,7 @@ object ALS { * level of parallelism. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -349,7 +349,7 @@ object ALS { * parallelism automatically based on the number of partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter */ @@ -366,7 +366,7 @@ object ALS { * parallelism automatically based on the number of partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS */ @Since("0.8.0") @@ -383,7 +383,7 @@ object ALS { * a level of parallelism given by `blocks`. * * @param ratings RDD of (userID, productID, rating) pairs - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -410,7 +410,7 @@ object ALS { * iteratively with a configurable level of parallelism. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -436,7 +436,7 @@ object ALS { * partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param alpha confidence parameter @@ -455,7 +455,7 @@ object ALS { * partitions in `ratings`. 
* * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS */ @Since("0.8.1") diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 732300ee9c2c..81182881352b 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -249,7 +249,7 @@ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative :param ratings: RDD of `Rating` or (userID, productID, rating) tuple. :param rank: - Rank of the feature matrices computed (number of features). + Number of features to use (also referred to as the number of latent factors). :param iterations: Number of iterations of ALS. (default: 5) @@ -287,7 +287,7 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp :param ratings: RDD of `Rating` or (userID, productID, rating) tuple. :param rank: - Rank of the feature matrices computed (number of features). + Number of features to use (also referred to as the number of latent factors). :param iterations: Number of iterations of ALS. (default: 5) From 650d03cfc9a609a2c603f9ced452d03ec8429b0d Mon Sep 17 00:00:00 2001 From: "jianran.tfh" Date: Tue, 21 Mar 2017 15:15:19 +0000 Subject: [PATCH 088/512] [SPARK-19998][BLOCK MANAGER] Change the exception log to add RDD id of the related the block ## What changes were proposed in this pull request? "java.lang.Exception: Could not compute split, block $blockId not found" doesn't have the rdd id info, the "BlockManager: Removing RDD $id" has only the RDD id, so it couldn't find that the Exception's reason is the Removing; so it's better block not found Exception add RDD id info ## How was this patch tested? Existing tests Author: jianran.tfh Author: jianran Closes #17334 from jianran/SPARK-19998. --- core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index d47b75544fdb..4e036c2ed49b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -47,7 +47,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo blockManager.get[T](blockId) match { case Some(block) => block.data.asInstanceOf[Iterator[T]] case None => - throw new Exception("Could not compute split, block " + blockId + " not found") + throw new Exception(s"Could not compute split, block $blockId of RDD $id not found") } } From 14865d7ff78db5cf9a3e8626204c8e7ed059c353 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 21 Mar 2017 08:44:09 -0700 Subject: [PATCH 089/512] [SPARK-17080][SQL][FOLLOWUP] Improve documentation, change buildJoin method structure and add a debug log ## What changes were proposed in this pull request? 1. Improve documentation for class `Cost` and `JoinReorderDP` and method `buildJoin()`. 2. Change code structure of `buildJoin()` to make the logic clearer. 3. Add a debug-level log to record information for join reordering, including time cost, the number of items and the number of plans in memo. ## How was this patch tested? Not related. Author: wangzhenhua Closes #17353 from wzhfy/reorderFollow. 
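The cost comparison this patch documents can be read in isolation. Below is a simplified, self-contained sketch: `PlanCost`, `betterThan` and `JoinCostSketch` are standalone stand-ins for the `Cost` and `JoinPlan.betterThan` code changed further down, and `cardWeight` plays the role of `spark.sql.cbo.joinReorder.card.weight`.

```scala
// Simplified sketch of the cost model and comparison described in this patch.
object JoinCostSketch extends App {
  // Cardinality (number of rows) and size in bytes, accumulated along the plan.
  case class PlanCost(card: BigInt, size: BigInt) {
    def +(other: PlanCost): PlanCost = PlanCost(card + other.card, size + other.size)
  }

  // A candidate beats the incumbent when its weighted relative cost is below 1.
  def betterThan(self: PlanCost, other: PlanCost, cardWeight: Double): Boolean = {
    if (other.card == 0 || other.size == 0) {
      false
    } else {
      val relativeRows = BigDecimal(self.card) / BigDecimal(other.card)
      val relativeSize = BigDecimal(self.size) / BigDecimal(other.size)
      relativeRows * cardWeight + relativeSize * (1 - cardWeight) < 1
    }
  }

  // Half the rows at the same size wins under an equal weighting.
  println(betterThan(PlanCost(50, 1000), PlanCost(100, 1000), cardWeight = 0.5))  // true
  // More rows but a much smaller output can still win when size carries more weight.
  println(betterThan(PlanCost(120, 200), PlanCost(100, 1000), cardWeight = 0.2))  // true
}
```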
--- .../optimizer/CostBasedJoinReorder.scala | 109 +++++++++++------- .../apache/spark/sql/internal/SQLConf.scala | 1 + 2 files changed, 68 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 521c468fe18a..fc37720809ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike} import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} @@ -51,7 +52,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } } - def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { + private def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) // TODO: Compute the set of star-joins and use them in the join enumeration // algorithm to prune un-optimal plan choices. @@ -69,7 +70,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } /** - * Extract consecutive inner joinable items and join conditions. + * Extracts items of consecutive inner joins and join conditions. * This method works for bushy trees and left/right deep trees. */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { @@ -119,18 +120,21 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr * When building m-way joins, we only keep the best plan (with the lowest cost) for the same set * of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among * plans (A J B) J C, (A J C) J B and (B J C) J A. - * - * Thus the plans maintained for each level when reordering four items A, B, C, D are as follows: + * We also prune cartesian product candidates when building a new plan if there exists no join + * condition involving references from both left and right. This pruning strategy significantly + * reduces the search space. + * E.g., given A J B J C J D with join conditions A.k1 = B.k1 and B.k2 = C.k2 and C.k3 = D.k3, + * plans maintained for each level are as follows: * level 0: p({A}), p({B}), p({C}), p({D}) - * level 1: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D}) - * level 2: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D}) + * level 1: p({A, B}), p({B, C}), p({C, D}) + * level 2: p({A, B, C}), p({B, C, D}) * level 3: p({A, B, C, D}) * where p({A, B, C, D}) is the final output plan. * * For cost evaluation, since physical costs for operators are not available currently, we use * cardinalities and sizes to compute costs. */ -object JoinReorderDP extends PredicateHelper { +object JoinReorderDP extends PredicateHelper with Logging { def search( conf: SQLConf, @@ -138,6 +142,7 @@ object JoinReorderDP extends PredicateHelper { conditions: Set[Expression], topOutput: AttributeSet): LogicalPlan = { + val startTime = System.nanoTime() // Level i maintains all found plans for i + 1 items. 
// Create the initial plans: each plan is a single item with zero cost. val itemIndex = items.zipWithIndex @@ -152,6 +157,10 @@ object JoinReorderDP extends PredicateHelper { foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) } + val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) + logDebug(s"Join reordering finished. Duration: $durationInMs ms, number of items: " + + s"${items.length}, number of plans in memo: ${foundPlans.map(_.size).sum}") + // The last level must have one and only one plan, because all items are joinable. assert(foundPlans.size == items.length && foundPlans.last.size == 1) foundPlans.last.head._2.plan @@ -183,18 +192,15 @@ object JoinReorderDP extends PredicateHelper { } otherSideCandidates.foreach { otherSidePlan => - // Should not join two overlapping item sets. - if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) { - val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) - if (joinPlan.isDefined) { - val newJoinPlan = joinPlan.get + buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) match { + case Some(newJoinPlan) => // Check if it's the first plan for the item set, or it's a better plan than // the existing one due to lower cost. val existingPlan = nextLevel.get(newJoinPlan.itemIds) if (existingPlan.isEmpty || newJoinPlan.betterThan(existingPlan.get, conf)) { nextLevel.update(newJoinPlan.itemIds, newJoinPlan) } - } + case None => } } } @@ -203,7 +209,17 @@ object JoinReorderDP extends PredicateHelper { nextLevel.toMap } - /** Build a new join node. */ + /** + * Builds a new JoinPlan when both conditions hold: + * - the sets of items contained in left and right sides do not overlap. + * - there exists at least one join condition involving references from both sides. + * @param oneJoinPlan One side JoinPlan for building a new JoinPlan. + * @param otherJoinPlan The other side JoinPlan for building a new join node. + * @param conf SQLConf for statistics computation. + * @param conditions The overall set of join conditions. + * @param topOutput The output attributes of the final plan. + * @return Builds and returns a new JoinPlan if both conditions hold. Otherwise, returns None. + */ private def buildJoin( oneJoinPlan: JoinPlan, otherJoinPlan: JoinPlan, @@ -211,6 +227,11 @@ object JoinReorderDP extends PredicateHelper { conditions: Set[Expression], topOutput: AttributeSet): Option[JoinPlan] = { + if (oneJoinPlan.itemIds.intersect(otherJoinPlan.itemIds).nonEmpty) { + // Should not join two overlapping item sets. + return None + } + val onePlan = oneJoinPlan.plan val otherPlan = otherJoinPlan.plan val joinConds = conditions @@ -220,33 +241,33 @@ object JoinReorderDP extends PredicateHelper { if (joinConds.isEmpty) { // Cartesian product is very expensive, so we exclude them from candidate plans. // This also significantly reduces the search space. - None + return None + } + + // Put the deeper side on the left, tend to build a left-deep tree. + val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { + (onePlan, otherPlan) } else { - // Put the deeper side on the left, tend to build a left-deep tree. 
- val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { - (onePlan, otherPlan) + (otherPlan, onePlan) + } + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds + val remainingConds = conditions -- collectedJoinConds + val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput + val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) + val newPlan = + if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { + Project(neededFromNewJoin.toSeq, newJoin) } else { - (otherPlan, onePlan) + newJoin } - val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) - val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds - val remainingConds = conditions -- collectedJoinConds - val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput - val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) - val newPlan = - if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { - Project(neededFromNewJoin.toSeq, newJoin) - } else { - newJoin - } - val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) - // Now the root node of onePlan/otherPlan becomes an intermediate join (if it's a non-leaf - // item), so the cost of the new join should also include its own cost. - val newPlanCost = oneJoinPlan.planCost + oneJoinPlan.rootCost(conf) + - otherJoinPlan.planCost + otherJoinPlan.rootCost(conf) - Some(JoinPlan(itemIds, newPlan, collectedJoinConds, newPlanCost)) - } + val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) + // Now the root node of onePlan/otherPlan becomes an intermediate join (if it's a non-leaf + // item), so the cost of the new join should also include its own cost. + val newPlanCost = oneJoinPlan.planCost + oneJoinPlan.rootCost(conf) + + otherJoinPlan.planCost + otherJoinPlan.rootCost(conf) + Some(JoinPlan(itemIds, newPlan, collectedJoinConds, newPlanCost)) } /** Map[set of item ids, join plan for these items] */ @@ -278,10 +299,10 @@ object JoinReorderDP extends PredicateHelper { } def betterThan(other: JoinPlan, conf: SQLConf): Boolean = { - if (other.planCost.rows == 0 || other.planCost.size == 0) { + if (other.planCost.card == 0 || other.planCost.size == 0) { false } else { - val relativeRows = BigDecimal(this.planCost.rows) / BigDecimal(other.planCost.rows) + val relativeRows = BigDecimal(this.planCost.card) / BigDecimal(other.planCost.card) val relativeSize = BigDecimal(this.planCost.size) / BigDecimal(other.planCost.size) relativeRows * conf.joinReorderCardWeight + relativeSize * (1 - conf.joinReorderCardWeight) < 1 @@ -290,7 +311,11 @@ object JoinReorderDP extends PredicateHelper { } } -/** This class defines the cost model. */ -case class Cost(rows: BigInt, size: BigInt) { - def +(other: Cost): Cost = Cost(this.rows + other.rows, this.size + other.size) +/** + * This class defines the cost model for a plan. + * @param card Cardinality (number of rows). + * @param size Size in bytes. 
+ */ +case class Cost(card: BigInt, size: BigInt) { + def +(other: Cost): Cost = Cost(this.card + other.card, this.size + other.size) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b6e0b8ccbeed..d5006c16469b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -708,6 +708,7 @@ object SQLConf { buildConf("spark.sql.cbo.joinReorder.dp.threshold") .doc("The maximum number of joined nodes allowed in the dynamic programming algorithm.") .intConf + .checkValue(number => number > 0, "The maximum number must be a positive integer.") .createWithDefault(12) val JOIN_REORDER_CARD_WEIGHT = From 63f077fbe50b4094340e9915db41d7dbdba52975 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 21 Mar 2017 08:45:59 -0700 Subject: [PATCH 090/512] [SPARK-20041][DOC] Update docs for NaN handling in approxQuantile ## What changes were proposed in this pull request? Update docs for NaN handling in approxQuantile. ## How was this patch tested? existing tests. Author: Zheng RuiFeng Closes #17369 from zhengruifeng/doc_quantiles_nan. --- R/pkg/R/stats.R | 3 ++- .../org/apache/spark/ml/feature/QuantileDiscretizer.scala | 4 ++-- python/pyspark/sql/dataframe.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 8d1d165052f7..d78a10893f92 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -149,7 +149,8 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed #' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 #' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. -#' Note that rows containing any NA values will be removed before calculation. +#' Note that NA values will be ignored in numerical columns before calculation. For +#' columns only containing NA values, an empty list is returned. #' #' @param x A SparkDataFrame. #' @param cols A single column name, or a list of names for multiple columns. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 80c7f55e26b8..feceeba866df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -93,8 +93,8 @@ private[feature] trait QuantileDiscretizerBase extends Params * are too few distinct values of the input to create enough distinct quantiles. * * NaN handling: - * NaN values will be removed from the column during `QuantileDiscretizer` fitting. This will - * produce a `Bucketizer` model for making predictions. During the transformation, + * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This + * will produce a `Bucketizer` model for making predictions. During the transformation, * `Bucketizer` will raise an error when it finds NaN values in the dataset, but the user can * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`. 
* If the user chooses to keep NaN values, they will be handled specially and placed into their own diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bb6df2268209..a24512f53c52 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1384,7 +1384,8 @@ def approxQuantile(self, col, probabilities, relativeError): Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. - Note that rows containing any null values will be removed before calculation. + Note that null values will be ignored in numerical columns before calculation. + For columns only containing null values, an empty list is returned. :param col: str, list. Can be a single column name, or a list of names for multiple columns. From 4c0ff5f58565f811b65f1a11b6121da007bcbd5f Mon Sep 17 00:00:00 2001 From: Xin Wu Date: Tue, 21 Mar 2017 08:49:54 -0700 Subject: [PATCH 091/512] [SPARK-19261][SQL] Alter add columns for Hive serde and some datasource tables ## What changes were proposed in this pull request? Support` ALTER TABLE ADD COLUMNS (...) `syntax for Hive serde and some datasource tables. In this PR, we consider a few aspects: 1. View is not supported for `ALTER ADD COLUMNS` 2. Since tables created in SparkSQL with Hive DDL syntax will populate table properties with schema information, we need make sure the consistency of the schema before and after ALTER operation in order for future use. 3. For embedded-schema type of format, such as `parquet`, we need to make sure that the predicate on the newly-added columns can be evaluated properly, or pushed down properly. In case of the data file does not have the columns for the newly-added columns, such predicates should return as if the column values are NULLs. 4. For datasource table, this feature does not support the following: 4.1 TEXT format, since there is only one default column `value` is inferred for text format data. 4.2 ORC format, since SparkSQL native ORC reader does not support the difference between user-specified-schema and inferred schema from ORC files. 4.3 Third party datasource types that implements RelationProvider, including the built-in JDBC format, since different implementations by the vendors may have different ways to dealing with schema. 4.4 Other datasource types, such as `parquet`, `json`, `csv`, `hive` are supported. 5. Column names being added can not be duplicate of any existing data column or partition column names. Case sensitivity is taken into consideration according to the sql configuration. 6. This feature also supports In-Memory catalog, while Hive support is turned off. ## How was this patch tested? Add new test cases Author: Xin Wu Closes #16626 from xwu0226/alter_add_columns. 
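For the end-user view of the new syntax, the snippet below mirrors the test cases added in this patch. It assumes only a `SparkSession` named `spark`, for example the one provided by `spark-shell`.

```scala
// Assumes a SparkSession named `spark`, e.g. the one provided by spark-shell.
spark.sql("CREATE TABLE t1 (c1 INT) USING parquet")
spark.sql("INSERT INTO t1 VALUES (1)")
spark.sql("ALTER TABLE t1 ADD COLUMNS (c2 INT)")

// Rows written before the ALTER expose the new column as NULL.
spark.sql("SELECT * FROM t1 WHERE c2 IS NULL").show()   // 1, null

spark.sql("INSERT INTO t1 VALUES (3, 2)")
spark.sql("SELECT * FROM t1 WHERE c2 = 2").show()       // 3, 2

// Views, text/ORC datasource tables, and RelationProvider-based sources such as JDBC
// are rejected with an AnalysisException, as listed above.
```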
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../sql/catalyst/catalog/SessionCatalog.scala | 56 ++++++++ .../catalog/SessionCatalogSuite.scala | 29 +++++ .../spark/sql/execution/SparkSqlParser.scala | 16 +++ .../spark/sql/execution/command/tables.scala | 76 ++++++++++- .../execution/command/DDLCommandSuite.scala | 8 +- .../sql/execution/command/DDLSuite.scala | 122 ++++++++++++++++++ .../sql/hive/execution/HiveDDLSuite.scala | 100 +++++++++++++- 8 files changed, 400 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index cc3b8fd3b468..c4a590ec6916 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -85,6 +85,8 @@ statement LIKE source=tableIdentifier locationSpec? #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq)? #analyze + | ALTER TABLE tableIdentifier + ADD COLUMNS '(' columns=colTypeList ')' #addTableColumns | ALTER (TABLE | VIEW) from=tableIdentifier RENAME TO to=tableIdentifier #renameTable | ALTER (TABLE | VIEW) tableIdentifier @@ -198,7 +200,6 @@ unsupportedHiveNativeCommands | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT - | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=ADD kw4=COLUMNS | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS | kw1=START kw2=TRANSACTION | kw1=COMMIT diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index b134fd44a311..a469d1245164 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.types.{StructField, StructType} object SessionCatalog { val DEFAULT_DATABASE = "default" @@ -161,6 +162,20 @@ class SessionCatalog( throw new TableAlreadyExistsException(db = db, table = name.table) } } + + private def checkDuplication(fields: Seq[StructField]): Unit = { + val columnNames = if (conf.caseSensitiveAnalysis) { + fields.map(_.name) + } else { + fields.map(_.name.toLowerCase) + } + if (columnNames.distinct.length != columnNames.length) { + val duplicateColumns = columnNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => x + } + throw new AnalysisException(s"Found duplicate column(s): ${duplicateColumns.mkString(", ")}") + } + } // ---------------------------------------------------------------------------- // Databases // ---------------------------------------------------------------------------- @@ -295,6 +310,47 @@ class SessionCatalog( externalCatalog.alterTable(newTableDefinition) } + /** + * Alter the schema of a table identified by the provided table identifier. 
The new schema + * should still contain the existing bucket columns and partition columns used by the table. This + * method will also update any Spark SQL-related parameters stored as Hive table properties (such + * as the schema itself). + * + * @param identifier TableIdentifier + * @param newSchema Updated schema to be used for the table (must contain existing partition and + * bucket columns, and partition columns need to be at the end) + */ + def alterTableSchema( + identifier: TableIdentifier, + newSchema: StructType): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + checkDuplication(newSchema) + + val catalogTable = externalCatalog.getTable(db, table) + val oldSchema = catalogTable.schema + + // not supporting dropping columns yet + val nonExistentColumnNames = oldSchema.map(_.name).filterNot(columnNameResolved(newSchema, _)) + if (nonExistentColumnNames.nonEmpty) { + throw new AnalysisException( + s""" + |Some existing schema fields (${nonExistentColumnNames.mkString("[", ",", "]")}) are + |not present in the new schema. We don't support dropping columns yet. + """.stripMargin) + } + + // assuming the newSchema has all partition columns at the end as required + externalCatalog.alterTableSchema(db, table, newSchema) + } + + private def columnNameResolved(schema: StructType, colName: String): Boolean = { + schema.fields.map(_.name).exists(conf.resolver(_, colName)) + } + /** * Return whether a table/view with the specified name exists. If no database is specified, check * with current database. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index fd9e5d6bb13e..ca4ce1c11707 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias, View} +import org.apache.spark.sql.types._ class InMemorySessionCatalogSuite extends SessionCatalogSuite { protected val utils = new CatalogTestUtils { @@ -448,6 +449,34 @@ abstract class SessionCatalogSuite extends PlanTest { } } + test("alter table add columns") { + withBasicCatalog { sessionCatalog => + sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) + val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") + sessionCatalog.alterTableSchema( + TableIdentifier("t1", Some("default")), + StructType(oldTab.dataSchema.add("c3", IntegerType) ++ oldTab.partitionSchema)) + + val newTab = sessionCatalog.externalCatalog.getTable("default", "t1") + // construct the expected table schema + val expectedTableSchema = StructType(oldTab.dataSchema.fields ++ + Seq(StructField("c3", IntegerType)) ++ oldTab.partitionSchema) + assert(newTab.schema == expectedTableSchema) + } + } + + test("alter table drop columns") { + withBasicCatalog { sessionCatalog => + sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) + val oldTab = 
sessionCatalog.externalCatalog.getTable("default", "t1") + val e = intercept[AnalysisException] { + sessionCatalog.alterTableSchema( + TableIdentifier("t1", Some("default")), StructType(oldTab.schema.drop(1))) + }.getMessage + assert(e.contains("We don't support dropping columns yet.")) + } + } + test("get table") { withBasicCatalog { catalog => assert(catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index abea7a3bcf14..d4f23f9dd518 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -741,6 +741,22 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx.VIEW != null) } + /** + * Create a [[AlterTableAddColumnsCommand]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} + */ + override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableAddColumnsCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitColTypeList(ctx.columns) + ) + } + /** * Create an [[AlterTableSetPropertiesCommand]] command. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index beb3dcafd64f..93307fc88356 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -37,7 +37,10 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -174,6 +177,77 @@ case class AlterTableRenameCommand( } +/** + * A command that add columns to a table + * The syntax of using this command in SQL is: + * {{{ + * ALTER TABLE table_identifier + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} +*/ +case class AlterTableAddColumnsCommand( + table: TableIdentifier, + columns: Seq[StructField]) extends RunnableCommand { + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val catalogTable = verifyAlterTableAddColumn(catalog, table) + + try { + sparkSession.catalog.uncacheTable(table.quotedString) + } catch { + case NonFatal(e) => + log.warn(s"Exception when attempting to uncache table ${table.quotedString}", e) + } + catalog.refreshTable(table) + + // make sure any partition columns are at the end of the fields + val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema + catalog.alterTableSchema( + table, catalogTable.schema.copy(fields = reorderedSchema.toArray)) + + Seq.empty[Row] + } + + /** + 
* ALTER TABLE ADD COLUMNS command does not support temporary view/table, + * view, or datasource table with text, orc formats or external provider. + * For datasource table, it currently only supports parquet, json, csv. + */ + private def verifyAlterTableAddColumn( + catalog: SessionCatalog, + table: TableIdentifier): CatalogTable = { + val catalogTable = catalog.getTempViewOrPermanentTableMetadata(table) + + if (catalogTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s""" + |ALTER ADD COLUMNS does not support views. + |You must drop and re-create the views for adding the new columns. Views: $table + """.stripMargin) + } + + if (DDLUtils.isDatasourceTable(catalogTable)) { + DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match { + // For datasource table, this command can only support the following File format. + // TextFileFormat only default to one column "value" + // OrcFileFormat can not handle difference between user-specified schema and + // inferred schema yet. TODO, once this issue is resolved , we can add Orc back. + // Hive type is already considered as hive serde table, so the logic will not + // come in here. + case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat => + case s => + throw new AnalysisException( + s""" + |ALTER ADD COLUMNS does not support datasource table with type $s. + |You must drop and re-create the table for adding the new columns. Tables: $table + """.stripMargin) + } + } + catalogTable + } +} + + /** * A command that loads data into a Hive table. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 4b73b078da38..13202a57851e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -780,13 +780,7 @@ class DDLCommandSuite extends PlanTest { assertUnsupported("ALTER TABLE table_name SKEWED BY (key) ON (1,5,6) STORED AS DIRECTORIES") } - test("alter table: add/replace columns (not allowed)") { - assertUnsupported( - """ - |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') - |ADD COLUMNS (new_col1 INT COMMENT 'test_comment', new_col2 LONG - |COMMENT 'test_comment2') CASCADE - """.stripMargin) + test("alter table: replace columns (not allowed)") { assertUnsupported( """ |ALTER TABLE table_name REPLACE COLUMNS (new_col1 INT diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 235c6bf6ad59..648b1798c66e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2185,4 +2185,126 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } } + + val supportedNativeFileFormatsForAlterTableAddColumns = Seq("parquet", "json", "csv") + + supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => + test(s"alter datasource table add columns - $provider") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int) USING $provider") + sql("INSERT INTO t1 VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 is null"), + Seq(Row(1, null)) + ) + + sql("INSERT INTO t1 
VALUES (3, 2)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 2"), + Seq(Row(3, 2)) + ) + } + } + } + + supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => + test(s"alter datasource table add columns - partitioned - $provider") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") + sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null, 2)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 is null"), + Seq(Row(1, null, 2)) + ) + sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 = 3"), + Seq(Row(2, 3, 1)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 1"), + Seq(Row(2, 3, 1)) + ) + } + } + } + + test("alter datasource table add columns - text format not supported") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING text") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + }.getMessage + assert(e.contains("ALTER ADD COLUMNS does not support datasource table with type")) + } + } + + test("alter table add columns -- not support temp view") { + withTempView("tmp_v") { + sql("CREATE TEMPORARY VIEW tmp_v AS SELECT 1 AS c1, 2 AS c2") + val e = intercept[AnalysisException] { + sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)") + } + assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + } + } + + test("alter table add columns -- not support view") { + withView("v1") { + sql("CREATE VIEW v1 AS SELECT 1 AS c1, 2 AS c2") + val e = intercept[AnalysisException] { + sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)") + } + assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + } + } + + test("alter table add columns with existing column name") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING PARQUET") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (c1 string)") + }.getMessage + assert(e.contains("Found duplicate column(s)")) + } + } + + Seq(true, false).foreach { caseSensitive => + test(s"alter table add columns with existing column name - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING PARQUET") + if (!caseSensitive) { + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + }.getMessage + assert(e.contains("Found duplicate column(s)")) + } else { + if (isUsingHiveMetastore) { + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. 
+ val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + }.getMessage + assert(e.contains("HiveException")) + } else { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + assert(spark.table("t1").schema + .equals(new StructType().add("c1", IntegerType).add("C1", StringType))) + } + } + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index d752c415c1ed..04bc79d43032 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{MetadataBuilder, StructType} +import org.apache.spark.sql.types._ // TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach { @@ -112,6 +112,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { import testImplicits._ + val hiveFormats = Seq("PARQUET", "ORC", "TEXTFILE", "SEQUENCEFILE", "RCFILE", "AVRO") override def afterEach(): Unit = { try { @@ -1860,4 +1861,101 @@ class HiveDDLSuite } } } + + hiveFormats.foreach { tableType => + test(s"alter hive serde table add columns -- partitioned - $tableType") { + withTable("tab") { + sql( + s""" + |CREATE TABLE tab (c1 int, c2 int) + |PARTITIONED BY (c3 int) STORED AS $tableType + """.stripMargin) + + sql("INSERT INTO tab PARTITION (c3=1) VALUES (1, 2)") + sql("ALTER TABLE tab ADD COLUMNS (c4 int)") + + checkAnswer( + sql("SELECT * FROM tab WHERE c3 = 1"), + Seq(Row(1, 2, null, 1)) + ) + assert(spark.table("tab").schema + .contains(StructField("c4", IntegerType))) + sql("INSERT INTO tab PARTITION (c3=2) VALUES (2, 3, 4)") + checkAnswer( + spark.table("tab"), + Seq(Row(1, 2, null, 1), Row(2, 3, 4, 2)) + ) + checkAnswer( + sql("SELECT * FROM tab WHERE c3 = 2 AND c4 IS NOT NULL"), + Seq(Row(2, 3, 4, 2)) + ) + + sql("ALTER TABLE tab ADD COLUMNS (c5 char(10))") + assert(spark.table("tab").schema.find(_.name == "c5") + .get.metadata.getString("HIVE_TYPE_STRING") == "char(10)") + } + } + } + + hiveFormats.foreach { tableType => + test(s"alter hive serde table add columns -- with predicate - $tableType ") { + withTable("tab") { + sql(s"CREATE TABLE tab (c1 int, c2 int) STORED AS $tableType") + sql("INSERT INTO tab VALUES (1, 2)") + sql("ALTER TABLE tab ADD COLUMNS (c4 int)") + checkAnswer( + sql("SELECT * FROM tab WHERE c4 IS NULL"), + Seq(Row(1, 2, null)) + ) + assert(spark.table("tab").schema + .contains(StructField("c4", IntegerType))) + sql("INSERT INTO tab VALUES (2, 3, 4)") + checkAnswer( + sql("SELECT * FROM tab WHERE c4 = 4 "), + Seq(Row(2, 3, 4)) + ) + checkAnswer( + spark.table("tab"), + Seq(Row(1, 2, null), Row(2, 3, 4)) + ) + } + } + } + + Seq(true, false).foreach { caseSensitive => + test(s"alter add columns with existing column name - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withTable("tab") { + sql("CREATE TABLE tab (c1 int) PARTITIONED BY (c2 int) STORED AS PARQUET") + if 
(!caseSensitive) { + // duplicating partitioning column name + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C2 string)") + }.getMessage + assert(e1.contains("Found duplicate column(s)")) + + // duplicating data column name + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C1 string)") + }.getMessage + assert(e2.contains("Found duplicate column(s)")) + } else { + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C2 string)") + }.getMessage + assert(e1.contains("HiveException")) + + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C1 string)") + }.getMessage + assert(e2.contains("HiveException")) + } + } + } + } + } } From ae4b91d1f5734b9d66f3b851b71b3c179f3cdd76 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 21 Mar 2017 11:01:25 -0700 Subject: [PATCH 092/512] [SPARK-20039][ML] rename ChiSquare to ChiSquareTest ## What changes were proposed in this pull request? I realized that since ChiSquare is in the package stat, it's pretty unclear if it's the hypothesis test, distribution, or what. This PR renames it to ChiSquareTest to clarify this. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #17368 from jkbradley/SPARK-20039. --- .../ml/stat/{ChiSquare.scala => ChiSquareTest.scala} | 2 +- .../{ChiSquareSuite.scala => ChiSquareTestSuite.scala} | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/stat/{ChiSquare.scala => ChiSquareTest.scala} (99%) rename mllib/src/test/scala/org/apache/spark/ml/stat/{ChiSquareSuite.scala => ChiSquareTestSuite.scala} (94%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala rename to mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala index c3865ce6a9e2..21eba9a49809 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.functions.col */ @Experimental @Since("2.2.0") -object ChiSquare { +object ChiSquareTest { /** Used to construct output schema of tests */ private case class ChiSquareResult( diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala similarity index 94% rename from mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala index b4bed82e4d00..2d6aad0808bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext -class ChiSquareSuite +class ChiSquareTestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import testImplicits._ @@ -45,7 +45,7 @@ class ChiSquareSuite 
LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) for (numParts <- List(2, 4, 6, 8)) { val df = spark.createDataFrame(sc.parallelize(data, numParts)) - val chi = ChiSquare.test(df, "features", "label") + val chi = ChiSquareTest.test(df, "features", "label") val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = chi.select("pValues", "degreesOfFreedom", "statistics") .as[(Vector, Array[Int], Vector)].head() @@ -62,7 +62,7 @@ class ChiSquareSuite LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) val df = spark.createDataFrame(sparseData) - val chi = ChiSquare.test(df, "features", "label") + val chi = ChiSquareTest.test(df, "features", "label") val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = chi.select("pValues", "degreesOfFreedom", "statistics") .as[(Vector, Array[Int], Vector)].head() @@ -83,7 +83,7 @@ class ChiSquareSuite withClue("ChiSquare should throw an exception when given a continuous-valued label") { intercept[SparkException] { val df = spark.createDataFrame(continuousLabel) - ChiSquare.test(df, "features", "label") + ChiSquareTest.test(df, "features", "label") } } val continuousFeature = Seq.fill(tooManyCategories)( @@ -91,7 +91,7 @@ class ChiSquareSuite withClue("ChiSquare should throw an exception when given continuous-valued features") { intercept[SparkException] { val df = spark.createDataFrame(continuousFeature) - ChiSquare.test(df, "features", "label") + ChiSquareTest.test(df, "features", "label") } } } From 7dbc162f12cc1a447c85a1a2c20d32ebb5cbeacf Mon Sep 17 00:00:00 2001 From: zhaorongsheng <334362872@qq.com> Date: Tue, 21 Mar 2017 11:30:55 -0700 Subject: [PATCH 093/512] [SPARK-20017][SQL] change the nullability of function 'StringToMap' from 'false' to 'true' ## What changes were proposed in this pull request? Change the nullability of function `StringToMap` from `false` to `true`. Author: zhaorongsheng <334362872@qq.com> Closes #17350 from zhaorongsheng/bug-fix_strToMap_NPE. --- .../sql/catalyst/expressions/complexTypeCreator.scala | 4 +++- .../spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 22277ad8d56e..b6675a84ece4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -390,6 +390,8 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateName Examples: > SELECT _FUNC_('a:1,b:2,c:3', ',', ':'); map("a":"1","b":"2","c":"3") + > SELECT _FUNC_('a'); + map("a":null) """) // scalastyle:on line.size.limit case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) @@ -407,7 +409,7 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def dataType: DataType = MapType(StringType, StringType, valueContainsNull = false) + override def dataType: DataType = MapType(StringType, StringType) override def checkInputDataTypes(): TypeCheckResult = { if (Seq(pairDelim, keyValueDelim).exists(! 
_.foldable)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index abe1d2b2c99e..5f8a8f44d48e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -251,6 +251,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("StringToMap") { + val expectedDataType = MapType(StringType, StringType, valueContainsNull = true) + assert(new StringToMap("").dataType === expectedDataType) + val s0 = Literal("a:1,b:2,c:3") val m0 = Map("a" -> "1", "b" -> "2", "c" -> "3") checkEvaluation(new StringToMap(s0), m0) @@ -271,6 +274,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val m4 = Map("a" -> "1", "b" -> "2", "c" -> "3") checkEvaluation(new StringToMap(s4, Literal("_")), m4) + val s5 = Literal("a") + val m5 = Map("a" -> null) + checkEvaluation(new StringToMap(s5), m5) + // arguments checking assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess) assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure) From a8877bdbba6df105740f909bc87a13cdd4440757 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 21 Mar 2017 14:24:41 -0700 Subject: [PATCH 094/512] [SPARK-19237][SPARKR][CORE] On Windows spark-submit should handle when java is not installed ## What changes were proposed in this pull request? When SparkR is installed as a R package there might not be any java runtime. If it is not there SparkR's `sparkR.session()` will block waiting for the connection timeout, hanging the R IDE/shell, without any notification or message. ## How was this patch tested? manually - [x] need to test on Windows Author: Felix Cheung Closes #16596 from felixcheung/rcheckjava. --- R/pkg/inst/tests/testthat/test_Windows.R | 1 + bin/spark-class2.cmd | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index e8d983426a67..1d777ddb286d 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -20,6 +20,7 @@ test_that("sparkJars tag in SparkContext", { if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } + testOutput <- launchScript("ECHO", "a/b/c", wait = TRUE) abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 869c0b202f7f..9faa7d65f83e 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -50,7 +50,16 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. set RUNNER=java -if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java +if not "x%JAVA_HOME%"=="x" ( + set RUNNER="%JAVA_HOME%\bin\java" +) else ( + where /q "%RUNNER%" + if ERRORLEVEL 1 ( + echo Java not found and JAVA_HOME environment variable is not set. + echo Install Java and set JAVA_HOME to point to the Java installation directory. + exit /b 1 + ) +) rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. 
From a04dcde8cb191e591a5f5d7a67a5371e31e7343c Mon Sep 17 00:00:00 2001 From: Will Manning Date: Wed, 22 Mar 2017 00:40:48 +0100 Subject: [PATCH 095/512] clarify array_contains function description ## What changes were proposed in this pull request? The description in the comment for array_contains is vague/incomplete (i.e., doesn't mention that it returns `null` if the array is `null`); this PR fixes that. ## How was this patch tested? No testing, since it merely changes a comment. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Will Manning Closes #17380 from lwwmanning/patch-1. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a9f089c850d4..66bb8816a670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2896,7 +2896,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns true if the array contains `value` + * Returns null if the array is null, true if the array contains `value`, and false otherwise. * @group collection_funcs * @since 1.5.0 */ From 9281a3d504d526440c1d445075e38a6d9142ac93 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 22 Mar 2017 08:41:46 +0800 Subject: [PATCH 096/512] [SPARK-19919][SQL] Defer throwing the exception for empty paths in CSV datasource into `DataSource` ## What changes were proposed in this pull request? This PR proposes to defer throwing the exception within `DataSource`. Currently, if other datasources fail to infer the schema, it returns `None` and then this is being validated in `DataSource` as below: ``` scala> spark.read.json("emptydir") org.apache.spark.sql.AnalysisException: Unable to infer schema for JSON. It must be specified manually.; ``` ``` scala> spark.read.orc("emptydir") org.apache.spark.sql.AnalysisException: Unable to infer schema for ORC. It must be specified manually.; ``` ``` scala> spark.read.parquet("emptydir") org.apache.spark.sql.AnalysisException: Unable to infer schema for Parquet. It must be specified manually.; ``` However, CSV it checks it within the datasource implementation and throws another exception message as below: ``` scala> spark.read.csv("emptydir") java.lang.IllegalArgumentException: requirement failed: Cannot infer schema from an empty set of files ``` We could remove this duplicated check and validate this in one place in the same way with the same message. ## How was this patch tested? Unit test in `CSVSuite` and manual test. Author: hyukjinkwon Closes #17256 from HyukjinKwon/SPARK-19919. 
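The refactoring in the diff that follows is essentially a template-method change: the public schema-inference entry point becomes a `final` method that handles the empty-input case uniformly by returning `None`, while the format-specific implementations only handle the non-empty case. A minimal, self-contained sketch of that shape is below; the names and types are simplified stand-ins, not the actual Spark or Hadoop classes:

```scala
// Simplified stand-ins for StructType and Hadoop's FileStatus.
case class Schema(fields: Seq[String])
case class FileStatus(path: String)

abstract class DataSourceSketch {
  // Final wrapper: returning None on empty input lets a single caller raise one
  // uniform "Unable to infer schema" error for every file format.
  final def inferSchema(inputPaths: Seq[FileStatus]): Option[Schema] = {
    if (inputPaths.nonEmpty) Some(infer(inputPaths)) else None
  }

  // Implementations only deal with the non-empty case and return a plain Schema.
  protected def infer(inputPaths: Seq[FileStatus]): Schema
}

object CsvLikeSource extends DataSourceSketch {
  override protected def infer(inputPaths: Seq[FileStatus]): Schema =
    Schema(Seq("_c0")) // pretend the first file's header was inspected
}

object SketchMain {
  def main(args: Array[String]): Unit = {
    assert(CsvLikeSource.inferSchema(Nil).isEmpty)                        // empty: defer the error to the caller
    assert(CsvLikeSource.inferSchema(Seq(FileStatus("a.csv"))).isDefined) // non-empty: format-specific inference
  }
}
```

This mirrors why `CSVFileFormat` can drop its `require(files.nonEmpty, ...)` check in the diff below: the `None` returned by `inferSchema` is validated in one place, so CSV reports the same "Unable to infer schema" `AnalysisException` as the JSON, ORC, and Parquet paths.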
--- .../datasources/csv/CSVDataSource.scala | 25 +++++++++++++------ .../datasources/csv/CSVFileFormat.scala | 4 +-- .../sql/test/DataFrameReaderWriterSuite.scala | 6 +++-- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 63af18ec5b8e..83bdf6fe224b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -54,10 +54,21 @@ abstract class CSVDataSource extends Serializable { /** * Infers the schema from `inputPaths` files. */ - def infer( + final def inferSchema( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] + parsedOptions: CSVOptions): Option[StructType] = { + if (inputPaths.nonEmpty) { + Some(infer(sparkSession, inputPaths, parsedOptions)) + } else { + None + } + } + + protected def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): StructType /** * Generates a header from the given row which is null-safe and duplicate-safe. @@ -131,10 +142,10 @@ object TextInputCSVDataSource extends CSVDataSource { override def infer( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] = { + parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption - Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)) + inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } /** @@ -203,7 +214,7 @@ object WholeFileCSVDataSource extends CSVDataSource { override def infer( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] = { + parsedOptions: CSVOptions): StructType = { val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) csv.flatMap { lines => UnivocityParser.tokenizeStream( @@ -222,10 +233,10 @@ object WholeFileCSVDataSource extends CSVDataSource { parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + CSVInferSchema.infer(tokenRDD, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. 
- Some(StructType(Nil)) + StructType(Nil) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index eef43c7629c1..a99bdfee5d6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -51,12 +51,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - require(files.nonEmpty, "Cannot infer schema from an empty set of files") - val parsedOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions) + CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) } override def prepareWrite( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8a8ba0553452..8287776f8f55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -370,9 +370,11 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be val schema = df.schema // Reader, without user specified schema - intercept[IllegalArgumentException] { + val message = intercept[AnalysisException] { testRead(spark.read.csv(), Seq.empty, schema) - } + }.getMessage + assert(message.contains("Unable to infer schema for CSV. It must be specified manually.")) + testRead(spark.read.csv(dir), data, schema) testRead(spark.read.csv(dir, dir), data ++ data, schema) testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema) From 2d73fcced0492c606feab8fe84f62e8318ebcaa1 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Tue, 21 Mar 2017 18:56:14 -0700 Subject: [PATCH 097/512] [SPARK-20051][SS] Fix StreamSuite flaky test - recover from v2.1 checkpoint ## What changes were proposed in this pull request? There is a race condition between calling stop on a streaming query and deleting directories in `withTempDir` that causes test to fail, fixing to do lazy deletion using delete on shutdown JVM hook. ## How was this patch tested? - Unit test - repeated 300 runs with no failure Author: Kunal Khamar Closes #17382 from kunalkhamar/partition-bugfix. --- .../spark/sql/streaming/StreamSuite.scala | 77 +++++++++---------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index e867fc40f7f1..f01211e20cbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.util.Utils class StreamSuite extends StreamTest { @@ -438,52 +439,48 @@ class StreamSuite extends StreamTest { // 1 - Test if recovery from the checkpoint is successful. 
prepareMemoryStream() - withTempDir { dir => - // Copy the checkpoint to a temp dir to prevent changes to the original. - // Not doing this will lead to the test passing on the first run, but fail subsequent runs. - FileUtils.copyDirectory(checkpointDir, dir) - - // Checkpoint data was generated by a query with 10 shuffle partitions. - // In order to test reading from the checkpoint, the checkpoint must have two or more batches, - // since the last batch may be rerun. - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - var streamingQuery: StreamingQuery = null - try { - streamingQuery = - query.queryName("counts").option("checkpointLocation", dir.getCanonicalPath).start() - streamingQuery.processAllAvailable() - inputData.addData(9) - streamingQuery.processAllAvailable() - - QueryTest.checkAnswer(spark.table("counts").toDF(), - Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) :: - Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil) - } finally { - if (streamingQuery ne null) { - streamingQuery.stop() - } + val dir1 = Utils.createTempDir().getCanonicalFile // not using withTempDir {}, makes test flaky + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(checkpointDir, dir1) + // Checkpoint data was generated by a query with 10 shuffle partitions. + // In order to test reading from the checkpoint, the checkpoint must have two or more batches, + // since the last batch may be rerun. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + var streamingQuery: StreamingQuery = null + try { + streamingQuery = + query.queryName("counts").option("checkpointLocation", dir1.getCanonicalPath).start() + streamingQuery.processAllAvailable() + inputData.addData(9) + streamingQuery.processAllAvailable() + + QueryTest.checkAnswer(spark.table("counts").toDF(), + Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) :: + Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil) + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() } } } // 2 - Check recovery with wrong num shuffle partitions prepareMemoryStream() - withTempDir { dir => - FileUtils.copyDirectory(checkpointDir, dir) - - // Since the number of partitions is greater than 10, should throw exception. - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") { - var streamingQuery: StreamingQuery = null - try { - intercept[StreamingQueryException] { - streamingQuery = - query.queryName("badQuery").option("checkpointLocation", dir.getCanonicalPath).start() - streamingQuery.processAllAvailable() - } - } finally { - if (streamingQuery ne null) { - streamingQuery.stop() - } + val dir2 = Utils.createTempDir().getCanonicalFile + FileUtils.copyDirectory(checkpointDir, dir2) + // Since the number of partitions is greater than 10, should throw exception. 
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") { + var streamingQuery: StreamingQuery = null + try { + intercept[StreamingQueryException] { + streamingQuery = + query.queryName("badQuery").option("checkpointLocation", dir2.getCanonicalPath).start() + streamingQuery.processAllAvailable() + } + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() } } } From c1e87e384d1878308b42da80bb3d65be512aab55 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 21 Mar 2017 21:27:08 -0700 Subject: [PATCH 098/512] [SPARK-20030][SS] Event-time-based timeout for MapGroupsWithState ## What changes were proposed in this pull request? Adding event time based timeout. The user sets the timeout timestamp directly using `KeyedState.setTimeoutTimestamp`. The keys times out when the watermark crosses the timeout timestamp. ## How was this patch tested? Unit tests Author: Tathagata Das Closes #17361 from tdas/SPARK-20030. --- .../sql/streaming/KeyedStateTimeout.java | 22 +- .../UnsupportedOperationChecker.scala | 96 +++-- .../sql/catalyst/plans/logical/object.scala | 3 +- .../analysis/UnsupportedOperationsSuite.scala | 16 + .../spark/sql/execution/SparkStrategies.scala | 3 +- .../FlatMapGroupsWithStateExec.scala | 87 ++-- .../streaming/IncrementalExecution.scala | 5 +- .../execution/streaming/KeyedStateImpl.scala | 139 ++++-- .../streaming/statefulOperators.scala | 14 +- .../spark/sql/streaming/KeyedState.scala | 97 ++++- .../FlatMapGroupsWithStateSuite.scala | 402 ++++++++++++------ 11 files changed, 616 insertions(+), 268 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java index cf112f2e02a9..e2e7ab1d2609 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java @@ -19,9 +19,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$; -import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout; -import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.*; /** * Represents the type of timeouts possible for the Dataset operations @@ -34,9 +32,23 @@ @InterfaceStability.Evolving public class KeyedStateTimeout { - /** Timeout based on processing time. */ + /** + * Timeout based on processing time. The duration of timeout can be set for each group in + * `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutDuration()`. See documentation + * on `KeyedState` for more details. + */ public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } - /** No timeout */ + /** + * Timeout based on event-time. The event-time timestamp for timeout can be set for each + * group in `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutTimestamp()`. + * In addition, you have to define the watermark in the query using `Dataset.withWatermark`. + * When the watermark advances beyond the set timestamp of a group and the group has not + * received any data, then the group times out. See documentation on + * `KeyedState` for more details. + */ + public static KeyedStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; } + + /** No timeout. 
*/ public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index a9ff61e0e880..7da7f55aa5d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -147,49 +147,69 @@ object UnsupportedOperationChecker { throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + "streaming DataFrames/Datasets") - // mapGroupsWithState: Allowed only when no aggregation + Update output mode - case m: FlatMapGroupsWithState if m.isStreaming && m.isMapGroupsWithState => - if (collectStreamingAggregates(plan).isEmpty) { - if (outputMode != InternalOutputModes.Update) { - throwError("mapGroupsWithState is not supported with " + - s"$outputMode output mode on a streaming DataFrame/Dataset") - } else { - // Allowed when no aggregation + Update output mode - } - } else { - throwError("mapGroupsWithState is not supported with aggregation " + - "on a streaming DataFrame/Dataset") - } - - // flatMapGroupsWithState without aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && collectStreamingAggregates(plan).isEmpty => - m.outputMode match { - case InternalOutputModes.Update => - if (outputMode != InternalOutputModes.Update) { - throwError("flatMapGroupsWithState in update mode is not supported with " + + // mapGroupsWithState and flatMapGroupsWithState + case m: FlatMapGroupsWithState if m.isStreaming => + + // Check compatibility with output modes and aggregations in query + val aggsAfterFlatMapGroups = collectStreamingAggregates(plan) + + if (m.isMapGroupsWithState) { // check mapGroupsWithState + // allowed only in update query output mode and without aggregation + if (aggsAfterFlatMapGroups.nonEmpty) { + throwError( + "mapGroupsWithState is not supported with aggregation " + + "on a streaming DataFrame/Dataset") + } else if (outputMode != InternalOutputModes.Update) { + throwError( + "mapGroupsWithState is not supported with " + s"$outputMode output mode on a streaming DataFrame/Dataset") + } + } else { // check latMapGroupsWithState + if (aggsAfterFlatMapGroups.isEmpty) { + // flatMapGroupsWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "flatMapGroupsWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "flatMapGroupsWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => } - case InternalOutputModes.Append => - if (outputMode != InternalOutputModes.Append) { - throwError("flatMapGroupsWithState in append mode is not supported with " + - s"$outputMode output mode on a streaming DataFrame/Dataset") + } else { + // flatMapGroupsWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "flatMapGroupsWithState in update mode is not supported with " + + "aggregation on 
a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "flatMapGroupsWithState in append mode is not supported after " + + s"aggregation on a streaming DataFrame/Dataset") } + } } - // flatMapGroupsWithState(Update) with aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && m.outputMode == InternalOutputModes.Update - && collectStreamingAggregates(plan).nonEmpty => - throwError("flatMapGroupsWithState in update mode is not supported with " + - "aggregation on a streaming DataFrame/Dataset") - - // flatMapGroupsWithState(Append) with aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && m.outputMode == InternalOutputModes.Append - && collectStreamingAggregates(m).nonEmpty => - throwError("flatMapGroupsWithState in append mode is not supported after " + - s"aggregation on a streaming DataFrame/Dataset") + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. + val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "[map|flatMap]GroupsWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index d1f95faf2db0..e0ecf8c5f264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -353,9 +353,10 @@ case class MapGroups( /** Internal class representing State */ trait LogicalKeyedState[S] -/** Possible types of timeouts used in FlatMapGroupsWithState */ +/** Types of timeouts used in FlatMapGroupsWithState */ case object NoTimeout extends KeyedStateTimeout case object ProcessingTimeTimeout extends KeyedStateTimeout +case object EventTimeTimeout extends KeyedStateTimeout /** Factory for constructing new `MapGroupsWithState` nodes. 
*/ object FlatMapGroupsWithState { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 08216e266040..8f0a0c0d99d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -345,6 +345,22 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates")) + // mapGroupsWithState with event time timeout + watermark + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState with event time timeout without watermark", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, + EventTimeTimeout, streamRelation), + outputMode = Update, + expectedMsgs = Seq("watermark")) + + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState with event time timeout with watermark", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, + EventTimeTimeout, new TestStreamingRelation(attributeWithWatermark)), + outputMode = Update) + // Deduplicate assertSupportedInStreamingPlan( "Deduplicate - Deduplicate on streaming relation before aggregation", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9e58e8ce3d5f..ca2f6dd7a84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -336,8 +336,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { timeout, child) => val execPlan = FlatMapGroupsWithStateExec( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, - timeout, batchTimestampMs = KeyedStateImpl.NO_BATCH_PROCESSING_TIMESTAMP, - planLater(child)) + timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 991d8ef70756..52ad70c7dc88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalKeyedState, ProcessingTimeTimeout} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, UnsafeRow} +import 
org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.KeyedStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} -import org.apache.spark.sql.types.{BooleanType, IntegerType} +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -39,7 +40,7 @@ import org.apache.spark.util.CompletionIterator * @param outputObjAttr used to define the output object * @param stateEncoder used to serialize/deserialize state before calling `func` * @param outputMode the output mode of `func` - * @param timeout used to timeout groups that have not received data in a while + * @param timeoutConf used to timeout groups that have not received data in a while * @param batchTimestampMs processing timestamp of the current batch. */ case class FlatMapGroupsWithStateExec( @@ -52,11 +53,15 @@ case class FlatMapGroupsWithStateExec( stateId: Option[OperatorStateId], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, - timeout: KeyedStateTimeout, - batchTimestampMs: Long, - child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { + timeoutConf: KeyedStateTimeout, + batchTimestampMs: Option[Long], + override val eventTimeWatermark: Option[Long], + child: SparkPlan + ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { - private val isTimeoutEnabled = timeout == ProcessingTimeTimeout + import KeyedStateImpl._ + + private val isTimeoutEnabled = timeoutConf != NoTimeout private val timestampTimeoutAttribute = AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() private val stateAttributes: Seq[Attribute] = { @@ -64,8 +69,6 @@ case class FlatMapGroupsWithStateExec( if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs } - import KeyedStateImpl._ - /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -74,9 +77,21 @@ case class FlatMapGroupsWithStateExec( override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) + override def keyExpressions: Seq[Attribute] = groupingAttributes + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver + // Throw errors early if parameters are not as expected + timeoutConf match { + case ProcessingTimeTimeout => + require(batchTimestampMs.nonEmpty) + case EventTimeTimeout => + require(eventTimeWatermark.nonEmpty) // watermark value has been populated + require(watermarkExpression.nonEmpty) // input schema has watermark attribute + case _ => + } + child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, getStateId.operatorId, @@ -84,15 +99,23 @@ case class FlatMapGroupsWithStateExec( groupingAttributes.toStructType, stateAttributes.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iterator) => + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val updater = new StateStoreUpdater(store) + // If timeout is based on event time, then filter late data based on watermark + val filteredIter = watermarkPredicateForData match { + case Some(predicate) if 
timeoutConf == EventTimeTimeout => + iter.filter(row => !predicate.eval(row)) + case None => + iter + } + // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - updater.updateStateForKeysWithData(iterator) ++ updater.updateStateForTimedOutKeys() + updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -124,7 +147,7 @@ case class FlatMapGroupsWithStateExec( private val stateSerializer = { val encoderSerializer = stateEncoder.namedExpressions if (isTimeoutEnabled) { - encoderSerializer :+ Literal(KeyedStateImpl.TIMEOUT_TIMESTAMP_NOT_SET) + encoderSerializer :+ Literal(KeyedStateImpl.NO_TIMESTAMP) } else { encoderSerializer } @@ -157,16 +180,19 @@ case class FlatMapGroupsWithStateExec( /** Find the groups that have timeout set and are timing out right now, and call the function */ def updateStateForTimedOutKeys(): Iterator[InternalRow] = { if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } val timingOutKeys = store.filter { case (_, stateRow) => val timeoutTimestamp = getTimeoutTimestamp(stateRow) - timeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET && timeoutTimestamp < batchTimestampMs + timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } timingOutKeys.flatMap { case (keyRow, stateRow) => - callFunctionAndUpdateState( - keyRow, - Iterator.empty, - Some(stateRow), - hasTimedOut = true) + callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true) } } else Iterator.empty } @@ -186,7 +212,11 @@ case class FlatMapGroupsWithStateExec( val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects val stateObjOption = getStateObj(prevStateRowOption) val keyedState = new KeyedStateImpl( - stateObjOption, batchTimestampMs, isTimeoutEnabled, hasTimedOut) + stateObjOption, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut) // Call function, get the returned objects and convert them to rows val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => @@ -196,8 +226,6 @@ case class FlatMapGroupsWithStateExec( // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - // Has the timeout information changed - if (keyedState.hasRemoved) { store.remove(keyRow) numUpdatedStateRows += 1 @@ -205,26 +233,25 @@ case class FlatMapGroupsWithStateExec( } else { val previousTimeoutTimestamp = prevStateRowOption match { case Some(row) => getTimeoutTimestamp(row) - case None => TIMEOUT_TIMESTAMP_NOT_SET + case None => NO_TIMESTAMP } - + val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp val stateRowToWrite = if (keyedState.hasUpdated) { getStateRow(keyedState.get) } else { prevStateRowOption.orNull } - val hasTimeoutChanged = keyedState.getTimeoutTimestamp != previousTimeoutTimestamp + val hasTimeoutChanged = 
currentTimeoutTimestamp != previousTimeoutTimestamp val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged if (shouldWriteState) { if (stateRowToWrite == null) { // This should never happen because checks in KeyedStateImpl should avoid cases // where empty state would need to be written - throw new IllegalStateException( - "Attempting to write empty state") + throw new IllegalStateException("Attempting to write empty state") } - setTimeoutTimestamp(stateRowToWrite, keyedState.getTimeoutTimestamp) + setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) store.put(keyRow.copy(), stateRowToWrite.copy()) numUpdatedStateRows += 1 } @@ -247,7 +274,7 @@ case class FlatMapGroupsWithStateExec( /** Returns the timeout timestamp of a state row is set */ def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else TIMEOUT_TIMESTAMP_NOT_SET + if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP } /** Set the timestamp in a state row */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a934c75a0245..0f0e4a91f8cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -108,7 +108,10 @@ class IncrementalExecution( case m: FlatMapGroupsWithStateExec => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - m.copy(stateId = Some(stateId), batchTimestampMs = offsetSeqMetadata.batchTimestampMs) + m.copy( + stateId = Some(stateId), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index ac421d395beb..edfd35bd5dd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -17,37 +17,45 @@ package org.apache.spark.sql.execution.streaming +import java.sql.Date + import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.streaming.KeyedState +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.execution.streaming.KeyedStateImpl._ +import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout} import org.apache.spark.unsafe.types.CalendarInterval + /** * Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. * @param optionalValue Optional value of the state * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp * for processing time timeouts - * @param isTimeoutEnabled Whether timeout is enabled. This will be used to check whether the user - * is allowed to configure timeouts. + * @param timeoutConf Type of timeout configured. Based on this, different operations will + * be supported. * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. 
*/ private[sql] class KeyedStateImpl[S]( optionalValue: Option[S], batchProcessingTimeMs: Long, - isTimeoutEnabled: Boolean, + eventTimeWatermarkMs: Long, + timeoutConf: KeyedStateTimeout, override val hasTimedOut: Boolean) extends KeyedState[S] { - import KeyedStateImpl._ - // Constructor to create dummy state when using mapGroupsWithState in a batch query def this(optionalValue: Option[S]) = this( - optionalValue, -1, isTimeoutEnabled = false, hasTimedOut = false) + optionalValue, + batchProcessingTimeMs = NO_TIMESTAMP, + eventTimeWatermarkMs = NO_TIMESTAMP, + timeoutConf = KeyedStateTimeout.NoTimeout, + hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed - private var timeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET + private var timeoutTimestamp: Long = NO_TIMESTAMP // ========= Public API ========= override def exists: Boolean = defined @@ -82,13 +90,14 @@ private[sql] class KeyedStateImpl[S]( defined = false updated = false removed = true - timeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET + timeoutTimestamp = NO_TIMESTAMP } override def setTimeoutDuration(durationMs: Long): Unit = { - if (!isTimeoutEnabled) { + if (timeoutConf != ProcessingTimeTimeout) { throw new UnsupportedOperationException( - "Cannot set timeout information without enabling timeout in map/flatMapGroupsWithState") + "Cannot set timeout duration without enabling processing time timeout in " + + "map/flatMapGroupsWithState") } if (!defined) { throw new IllegalStateException( @@ -99,7 +108,7 @@ private[sql] class KeyedStateImpl[S]( if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") } - if (!removed && batchProcessingTimeMs != NO_BATCH_PROCESSING_TIMESTAMP) { + if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { timeoutTimestamp = durationMs + batchProcessingTimeMs } else { // This is being called in a batch query, hence no processing timestamp. 
@@ -108,29 +117,55 @@ private[sql] class KeyedStateImpl[S]( } override def setTimeoutDuration(duration: String): Unit = { - if (StringUtils.isBlank(duration)) { - throw new IllegalArgumentException( - "The window duration, slide duration and start time cannot be null or blank.") - } - val intervalString = if (duration.startsWith("interval")) { - duration - } else { - "interval " + duration + setTimeoutDuration(parseDuration(duration)) + } + + @throws[IllegalArgumentException]("if 'timestampMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestampMs: Long): Unit = { + checkTimeoutTimestampAllowed() + if (timestampMs <= 0) { + throw new IllegalArgumentException("Timeout timestamp must be positive") } - val cal = CalendarInterval.fromString(intervalString) - if (cal == null) { + if (eventTimeWatermarkMs != NO_TIMESTAMP && timestampMs < eventTimeWatermarkMs) { throw new IllegalArgumentException( - s"The provided duration ($duration) is not valid.") + s"Timeout timestamp ($timestampMs) cannot be earlier than the " + + s"current watermark ($eventTimeWatermarkMs)") } - if (cal.milliseconds < 0 || cal.months < 0) { - throw new IllegalArgumentException("Timeout duration must be positive") + if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { + timeoutTimestamp = timestampMs + } else { + // This is being called in a batch query, hence no processing timestamp. + // Just ignore any attempts to set timeout. } + } - val delayMs = { - val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 - cal.milliseconds + cal.months * millisPerMonth - } - setTimeoutDuration(delayMs) + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs) + } + + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestamp: Date): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(timestamp.getTime) + } + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration)) } override def toString: String = { @@ -147,14 +182,46 @@ private[sql] class KeyedStateImpl[S]( /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */ def getTimeoutTimestamp: Long = timeoutTimestamp + + private def parseDuration(duration: String): Long = { + if (StringUtils.isBlank(duration)) { + throw new 
IllegalArgumentException( + "Provided duration is null or blank.") + } + val intervalString = if (duration.startsWith("interval")) { + duration + } else { + "interval " + duration + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"Provided duration ($duration) is not valid.") + } + if (cal.milliseconds < 0 || cal.months < 0) { + throw new IllegalArgumentException(s"Provided duration ($duration) is not positive") + } + + val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + cal.milliseconds + cal.months * millisPerMonth + } + + private def checkTimeoutTimestampAllowed(): Unit = { + if (timeoutConf != EventTimeTimeout) { + throw new UnsupportedOperationException( + "Cannot set timeout timestamp without enabling event time timeout in " + + "map/flatMapGroupsWithState") + } + if (!defined) { + throw new IllegalStateException( + "Cannot set timeout timestamp without any state value, " + + "state has either not been initialized, or has already been removed") + } + } } private[sql] object KeyedStateImpl { - // Value used in the state row to represent the lack of any timeout timestamp - val TIMEOUT_TIMESTAMP_NOT_SET = -1L - - // Value to represent that no batch processing timestamp is passed to KeyedStateImpl. This is - // used in batch queries where there are no streaming batches and timeouts. - val NO_BATCH_PROCESSING_TIMESTAMP = -1L + // Value used represent the lack of valid timestamp as a long + val NO_TIMESTAMP = -1L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6d2de441eb44..f72144a25d5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -80,7 +80,7 @@ trait WatermarkSupport extends UnaryExecNode { /** Generate an expression that matches data older than the watermark */ lazy val watermarkExpression: Option[Expression] = { val optionalWatermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)) optionalWatermarkAttribute.map { watermarkAttribute => // If we are evicting based on a window, use the end of the window. Otherwise just @@ -101,14 +101,12 @@ trait WatermarkSupport extends UnaryExecNode { } } - /** Generate a predicate based on keys that matches data older than the watermark */ + /** Predicate based on keys that matches data older than the watermark */ lazy val watermarkPredicateForKeys: Option[Predicate] = watermarkExpression.map(newPredicate(_, keyExpressions)) - /** - * Generate a predicate based on the child output that matches data older than the watermark. - */ - lazy val watermarkPredicate: Option[Predicate] = + /** Predicate based on the child output that matches data older than the watermark. 
*/ + lazy val watermarkPredicateForData: Option[Predicate] = watermarkExpression.map(newPredicate(_, child.output)) } @@ -218,7 +216,7 @@ case class StateStoreSaveExec( new Iterator[InternalRow] { // Filter late date using watermark if specified - private[this] val baseIterator = watermarkPredicate match { + private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } @@ -285,7 +283,7 @@ case class StreamingDeduplicateExec( val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val baseIterator = watermarkPredicate match { + val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala index 6b4b1ced98a3..461de04f6bbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala @@ -55,7 +55,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * batch, nor with streaming Datasets. * - All the data will be shuffled before applying the function. * - If timeout is set, then the function will also be called with no values. - * See more details on KeyedStateTimeout` below. + * See more details on `KeyedStateTimeout` below. * * Important points to note about using `KeyedState`. * - The value of the state cannot be null. So updating state with null will throw @@ -68,20 +68,38 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * * Important points to note about using `KeyedStateTimeout`. * - The timeout type is a global param across all the keys (set as `timeout` param in - * `[map|flatMap]GroupsWithState`, but the exact timeout duration is configurable per key - * (by calling `setTimeout...()` in `KeyedState`). - * - When the timeout occurs for a key, the function is called with no values, and + * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per + * key by calling `setTimeout...()` in `KeyedState`. + * - Timeouts can be either based on processing time (i.e. + * [[KeyedStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. + * [[KeyedStateTimeout.EventTimeTimeout]]). + * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling + * `KeyedState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set + * duration. Guarantees provided by this timeout with a duration of D ms are as follows: + * - Timeout will never be occur before the clock time has advanced by D ms + * - Timeout will occur eventually when there is a trigger in the query + * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. + * For example, the trigger interval of the query will affect when the timeout actually occurs. + * If there is no data in the stream (for any key) for a while, then their will not be + * any trigger and timeout function call will not occur until there is data. + * - Since the processing time timeout is based on the clock time, it is affected by the + * variations in the system clock (i.e. time zone changes, clock skew, etc.). 
+ * - With `EventTimeTimeout`, the user also has to specify the the the event time watermark in + * the query using `Dataset.withWatermark()`. With this setting, data that is older than the + * watermark are filtered out. The timeout can be enabled for a key by setting a timestamp using + * `KeyedState.setTimeoutTimestamp()`, and the timeout would occur when the watermark advances + * beyond the set timestamp. You can control the timeout delay by two parameters - (i) watermark + * delay and an additional duration beyond the timestamp in the event (which is guaranteed to + * > watermark due to the filtering). Guarantees provided by this timeout are as follows: + * - Timeout will never be occur before watermark has exceeded the set timeout. + * - Similar to processing time timeouts, there is a no strict upper bound on the delay when + * the timeout actually occurs. The watermark can advance only when there is data in the + * stream, and the event time of the data has actually advanced. + * - When the timeout occurs for a key, the function is called for that key with no values, and * `KeyedState.hasTimedOut()` set to true. * - The timeout is reset for key every time the function is called on the key, that is, * when the key has new data, or the key has timed out. So the user has to set the timeout * duration every time the function is called, otherwise there will not be any timeout set. - * - Guarantees provided on processing-time-based timeout of key, when timeout duration is D ms: - * - Timeout will never be called before real clock time has advanced by D ms - * - Timeout will be called eventually when there is a trigger in the query - * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. - * For example, the trigger interval of the query will affect when the timeout is actually hit. - * If there is no data in the stream (for any key) for a while, then their will not be - * any trigger and timeout will not be hit until there is data. * * Scala example of using KeyedState in `mapGroupsWithState`: * {{{ @@ -194,7 +212,8 @@ trait KeyedState[S] extends LogicalKeyedState[S] { /** * Set the timeout duration in ms for this key. - * @note Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + * + * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ @throws[IllegalArgumentException]("if 'durationMs' is not positive") @throws[IllegalStateException]("when state is either not initialized, or already removed") @@ -204,11 +223,63 @@ trait KeyedState[S] extends LogicalKeyedState[S] { /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. - * @note, Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + * + * @note, ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. 
*/ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") def setTimeoutDuration(duration: String): Unit + + @throws[IllegalArgumentException]("if 'timestampMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as milliseconds in epoch time. + * This timestamp cannot be older than the current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestampMs: Long): Unit + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as milliseconds in epoch time and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). + * The final timestamp (including the additional duration) cannot be older than the + * current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit + + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as a java.sql.Date. + * This timestamp cannot be older than the current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestamp: java.sql.Date): Unit + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as a java.sql.Date and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). + * The final timestamp (including the additional duration) cannot be older than the + * current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. 
+ */ + def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 7daa5e6a0f61..fe72283bb608 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util +import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll @@ -44,6 +44,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf import testImplicits._ import KeyedStateImpl._ + import KeyedStateTimeout._ override def afterAll(): Unit = { super.afterAll() @@ -96,77 +97,93 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - test("KeyedState - setTimeoutDuration, hasTimedOut") { - import KeyedStateImpl._ - var state: KeyedStateImpl[Int] = null - - // When isTimeoutEnabled = false, then setTimeoutDuration() is not allowed + test("KeyedState - setTimeout**** with NoTimeout") { for (initState <- Seq(None, Some(5))) { // for different initial state - state = new KeyedStateImpl(initState, 1000, isTimeoutEnabled = false, hasTimedOut = false) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - intercept[UnsupportedOperationException] { - state.setTimeoutDuration(1000) - } - intercept[UnsupportedOperationException] { - state.setTimeoutDuration("1 day") - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + implicit val state = new KeyedStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } + } - def testTimeoutNotAllowed(): Unit = { - intercept[IllegalStateException] { - state.setTimeoutDuration(1000) - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - intercept[IllegalStateException] { - state.setTimeoutDuration("2 second") - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - } + test("KeyedState - setTimeout**** with ProcessingTimeTimeout") { + implicit var state: KeyedStateImpl[Int] = null - // When isTimeoutEnabled = true, then setTimeoutDuration() is not allowed until the - // state is be defined - state = new KeyedStateImpl(None, 1000, isTimeoutEnabled = true, hasTimedOut = false) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - testTimeoutNotAllowed() + state = new KeyedStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[IllegalStateException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - // After state has been set, setTimeoutDuration() is allowed, and - // getTimeoutTimestamp returned correct timestamp state.update(5) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(1000) assert(state.getTimeoutTimestamp === 2000) state.setTimeoutDuration("2 second") assert(state.getTimeoutTimestamp === 3000) - 
assert(state.hasTimedOut === false) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[IllegalStateException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } + + test("KeyedState - setTimeout**** with EventTimeTimeout") { + implicit val state = new KeyedStateImpl[Int]( + None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[IllegalStateException](state) + + state.update(5) + state.setTimeoutTimestamp(10000) + assert(state.getTimeoutTimestamp === 10000) + state.setTimeoutTimestamp(new Date(20000)) + assert(state.getTimeoutTimestamp === 20000) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[IllegalStateException](state) + } + + test("KeyedState - illegal params to setTimeout****") { + var state: KeyedStateImpl[Int] = null - // setTimeoutDuration() with negative values or 0 is not allowed + // Test setTimeout****() with illegal values def testIllegalTimeout(body: => Unit): Unit = { intercept[IllegalArgumentException] { body } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) + + state = new KeyedStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } testIllegalTimeout { state.setTimeoutDuration(0) } testIllegalTimeout { state.setTimeoutDuration("-2 second") } testIllegalTimeout { state.setTimeoutDuration("-1 month") } testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } - // Test remove() clear timeout timestamp, and setTimeoutDuration() is not allowed after that - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) - state.remove() - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - testTimeoutNotAllowed() - - // Test hasTimedOut = true - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = true) - assert(state.hasTimedOut === true) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + state = new KeyedStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + testIllegalTimeout { state.setTimeoutTimestamp(-10000) } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } + } + + test("KeyedState - hasTimedOut") { + for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { + for (initState <- Seq(None, Some(5))) { + val state1 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + 
assert(state1.hasTimedOut === false) + val state2 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + assert(state2.hasTimedOut === true) + } + } } test("KeyedState - primitive type") { @@ -187,133 +204,186 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } // Values used for testing StateStoreUpdater - val currentTimestamp = 1000 - val beforeCurrentTimestamp = 999 - val afterCurrentTimestamp = 1001 + val currentBatchTimestamp = 1000 + val currentBatchWatermark = 1000 + val beforeTimeoutThreshold = 999 + val afterTimeoutThreshold = 1001 + - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is disabled + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" - val testName = s"timeout disabled - $priorStateStr - " + val testName = s"NoTimeout - $priorStateStr - " testStateUpdateWithData( testName + "no update", stateUpdates = state => { /* do nothing */ }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change testStateUpdateWithData( testName + "state updated", stateUpdates = state => { state.update(5) }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = Some(5)) // should change testStateUpdateWithData( testName + "state removed", stateUpdates = state => { state.remove() }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is enabled + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { - for (priorTimeoutTimestamp <- Seq(TIMEOUT_TIMESTAMP_NOT_SET, 1000)) { - var testName = s"timeout enabled - " + for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { + var testName = s"" if (priorState.nonEmpty) { testName += "prior state set, " if (priorTimeoutTimestamp == 1000) { - testName += "prior timeout set - " + testName += "prior timeout set" } else { - testName += "no prior timeout - " + testName += "no prior timeout" } } else { - testName += "no prior state - " + testName += "no prior state" + } + for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { + + testStateUpdateWithData( + s"$timeoutConf - $testName - no update", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = priorState, // state should not change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithData( + s"$timeoutConf - $testName - state updated", + stateUpdates = state => { state.update(5) }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithData( + s"$timeoutConf - $testName - state removed", + stateUpdates = state => { state.remove() }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None) // state 
should be removed } testStateUpdateWithData( - testName + "no update", - stateUpdates = state => { /* do nothing */ }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, - priorState = priorState, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = priorState, // state should not change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset - - testStateUpdateWithData( - testName + "state updated", - stateUpdates = state => { state.update(5) }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"ProcessingTimeTimeout - $testName - state and timeout duration updated", + stateUpdates = + (state: KeyedState[Int]) => { state.update(5); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change testStateUpdateWithData( - testName + "state removed", - stateUpdates = state => { state.remove() }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"EventTimeTimeout - $testName - state and timeout timestamp updated", + stateUpdates = + (state: KeyedState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = None) // state should be removed + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = 5000) // timestamp should change testStateUpdateWithData( - testName + "timeout and state updated", - stateUpdates = state => { state.update(5); state.setTimeoutDuration(5000) }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"EventTimeTimeout - $testName - timeout timestamp updated to before watermark", + stateUpdates = + (state: KeyedState[Int]) => { + state.update(5) + intercept[IllegalArgumentException] { + state.setTimeoutTimestamp(currentBatchWatermark - 1) // try to set to < watermark + } + }, + timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = currentTimestamp + 5000) // timestamp should change + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update } } // Tests for StateStoreUpdater.updateStateForTimedOutKeys() val preTimeoutState = Some(5) + for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { + testStateUpdateWithTimeout( + s"$timeoutConf - should not timeout", + stateUpdates = state => { assert(false, "function called without timeout") }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = afterTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = afterTimeoutThreshold) // timestamp should not change + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - no update/remove", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset - testStateUpdateWithTimeout( - "should not timeout", - stateUpdates = state => { assert(false, "function 
called without timeout") }, - priorTimeoutTimestamp = afterCurrentTimestamp, - expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = afterCurrentTimestamp) // timestamp should not change + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - update state", + stateUpdates = state => { state.update(5) }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - remove state", + stateUpdates = state => { state.remove() }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = None, // state should be removed + expectedTimeoutTimestamp = NO_TIMESTAMP) + } testStateUpdateWithTimeout( - "should timeout - no update/remove", - stateUpdates = state => { /* do nothing */ }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "ProcessingTimeTimeout - should timeout - timeout duration updated", + stateUpdates = state => { state.setTimeoutDuration(2000) }, + timeoutConf = ProcessingTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedTimeoutTimestamp = currentBatchTimestamp + 2000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - update state", - stateUpdates = state => { state.update(5) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "ProcessingTimeTimeout - should timeout - timeout duration and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, + timeoutConf = ProcessingTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = Some(5), // state should change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedTimeoutTimestamp = currentBatchTimestamp + 2000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - remove state", - stateUpdates = state => { state.remove() }, - priorTimeoutTimestamp = beforeCurrentTimestamp, - expectedState = None, // state should be removed - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) - - testStateUpdateWithTimeout( - "should timeout - timeout updated", - stateUpdates = state => { state.setTimeoutDuration(2000) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "EventTimeTimeout - should timeout - timeout timestamp updated", + stateUpdates = state => { state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + expectedTimeoutTimestamp = 5000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - timeout and state updated", - stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "EventTimeTimeout - should timeout - timeout and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = Some(5), // state should change - expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + 
expectedTimeoutTimestamp = 5000) // timestamp should change test("StateStoreUpdater - rows are cloned before writing to StateStore") { // function for running count @@ -481,11 +551,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val clock = new StreamManualClock val inputData = MemoryStream[String] - val timeout = KeyedStateTimeout.ProcessingTimeTimeout val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Update, timeout)(stateFunc) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) testStream(result, Update)( StartStream(ProcessingTime("1 second"), triggerClock = clock), @@ -519,6 +588,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } + test("flatMapGroupsWithState - streaming with event time timeout") { + // Function to maintain the max event time + // Returns the max event time in the state, or -1 if the state was removed by timeout + val stateFunc = ( + key: String, + values: Iterator[(String, Long)], + state: KeyedState[Long]) => { + val timeoutDelay = 5 + if (key != "a") { + Iterator.empty + } else { + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampMs = maxEventTime + timeoutDelay + state.update(maxEventTime) + state.setTimeoutTimestamp(timeoutTimestampMs * 1000) + Iterator((key, maxEventTime.toInt)) + } + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(ProcessingTime("1 second")), + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... 
+ CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckLastBatch(), // No output as data should get filtered by watermark + AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s + CheckLastBatch(), // No output as no data for "a" + AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored + CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + ) + } + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -612,7 +727,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key val inputData = MemoryStream[String] val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) - result testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch("a"), @@ -649,13 +763,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdateWithData( testName: String, stateUpdates: KeyedState[Int] => Unit, - timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + timeoutConf: KeyedStateTimeout, priorState: Option[Int], - priorTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET, + priorTimeoutTimestamp: Long = NO_TIMESTAMP, expectedState: Option[Int] = None, - expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - if (priorState.isEmpty && priorTimeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET) { + if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } test(s"StateStoreUpdater - updates with data - $testName") { @@ -666,7 +780,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf Iterator.empty } testStateUpdate( - testTimeoutUpdates = false, mapGroupsFunc, timeoutType, + testTimeoutUpdates = false, mapGroupsFunc, timeoutConf, priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) } } @@ -674,9 +788,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdateWithTimeout( testName: String, stateUpdates: KeyedState[Int] => Unit, + timeoutConf: KeyedStateTimeout, priorTimeoutTimestamp: Long, expectedState: Option[Int], - expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { test(s"StateStoreUpdater - updates for timeout - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { @@ -686,16 +801,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf Iterator.empty } testStateUpdate( - testTimeoutUpdates = true, mapGroupsFunc, KeyedStateTimeout.ProcessingTimeTimeout, - preTimeoutState, priorTimeoutTimestamp, - expectedState, expectedTimeoutTimestamp) + testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf, + preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) } } def testStateUpdate( testTimeoutUpdates: Boolean, mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], - timeoutType: KeyedStateTimeout, + timeoutConf: KeyedStateTimeout, priorState: Option[Int], 
priorTimeoutTimestamp: Long, expectedState: Option[Int], @@ -703,7 +817,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( - mapGroupsFunc, timeoutType, currentTimestamp) + mapGroupsFunc, timeoutConf, currentBatchTimestamp) val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) val key = intToRow(0) // Prepare store with prior state configs @@ -736,7 +850,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def newFlatMapGroupsWithStateExec( func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, - batchTimestampMs: Long = NO_BATCH_PROCESSING_TIMESTAMP): FlatMapGroupsWithStateExec = { + batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { MemoryStream[Int] .toDS .groupByKey(x => x) @@ -744,11 +858,31 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( - f, k, v, g, d, o, None, s, m, t, currentTimestamp, - RDDScanExec(g, null, "rdd")) + f, k, v, g, d, o, None, s, m, t, + Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) }.get } + def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestamp + intercept[T] { state.setTimeoutDuration(1000) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutDuration("2 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + } + + def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestamp + intercept[T] { state.setTimeoutTimestamp(2000) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(2000, "1 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000)) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + } + def newStateStore(): StateStore = new MemoryStateStore() val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) From 478fbc866fbfdb4439788583281863ecea14e8af Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 21 Mar 2017 21:50:54 -0700 Subject: [PATCH 099/512] [SPARK-19925][SPARKR] Fix SparkR spark.getSparkFiles fails when it was called on executors. ## What changes were proposed in this pull request? SparkR ```spark.getSparkFiles``` fails when it was called on executors, see details at [SPARK-19925](https://issues.apache.org/jira/browse/SPARK-19925). ## How was this patch tested? Add unit tests, and verify this fix at standalone and yarn cluster. Author: Yanbo Liang Closes #17274 from yanboliang/spark-19925. 
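A minimal SparkR sketch of the call pattern this change is meant to support; the temp file and worker function below are illustrative assumptions, not taken from the patch:

```r
# Assumes an active SparkR session created with sparkR.session().
path <- tempfile(fileext = ".txt")
writeLines("hello", path)
spark.addFile(path)

# Resolve the distributed copy of the file from executor-side code.
# Before this fix, spark.getSparkFiles() errored on executors because it always
# went through a JVM call that is only reachable from the driver.
paths <- spark.lapply(1:4, function(i) {
  spark.getSparkFiles(basename(path))
})
```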
--- R/pkg/R/context.R | 16 ++++++++++++++-- R/pkg/inst/tests/testthat/test_context.R | 7 +++++++ .../scala/org/apache/spark/api/r/RRunner.scala | 2 ++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 1ca573e5bd61..50856e3d9856 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -330,7 +330,13 @@ spark.addFile <- function(path, recursive = FALSE) { #'} #' @note spark.getSparkFilesRootDirectory since 2.1.0 spark.getSparkFilesRootDirectory <- function() { - callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") + if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { + # Running on driver. + callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") + } else { + # Running on worker. + Sys.getenv("SPARKR_SPARKFILES_ROOT_DIR") + } } #' Get the absolute path of a file added through spark.addFile. @@ -345,7 +351,13 @@ spark.getSparkFilesRootDirectory <- function() { #'} #' @note spark.getSparkFiles since 2.1.0 spark.getSparkFiles <- function(fileName) { - callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) + if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { + # Running on driver. + callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) + } else { + # Running on worker. + file.path(spark.getSparkFilesRootDirectory(), as.character(fileName)) + } } #' Run a function over a list of elements, distributing the computations with Spark diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index caca06933952..c84711349111 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -177,6 +177,13 @@ test_that("add and get file to be downloaded with Spark job on every node", { spark.addFile(path) download_path <- spark.getSparkFiles(filename) expect_equal(readLines(download_path), words) + + # Test spark.getSparkFiles works well on executors. + seq <- seq(from = 1, to = 10, length.out = 5) + f <- function(seq) { spark.getSparkFiles(filename) } + results <- spark.lapply(seq, f) + for (i in 1:5) { expect_equal(basename(results[[i]]), filename) } + unlink(path) # Test add directory recursively. diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 29e21b3b1aa8..88118392003e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -347,6 +347,8 @@ private[r] object RRunner { pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) + pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) + pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) From 7343a09401e7d6636634968b1cd8bc403a1f77b6 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 22 Mar 2017 19:08:28 +0800 Subject: [PATCH 100/512] [SPARK-20023][SQL] Output table comment for DESC FORMATTED ### What changes were proposed in this pull request? Currently, `DESC FORMATTED` did not output the table comment, unlike what `DESC EXTENDED` does. This PR is to fix it. 
Also correct the following displayed names in `DESC FORMATTED`, for being consistent with `DESC EXTENDED` - `"Create Time:"` -> `"Created:"` - `"Last Access Time:"` -> `"Last Access:"` ### How was this patch tested? Added test cases in `describe.sql` Author: Xiao Li Closes #17381 from gatorsmile/descFormattedTableComment. --- .../spark/sql/execution/command/tables.scala | 5 +- .../resources/sql-tests/inputs/describe.sql | 14 +- .../sql-tests/results/describe.sql.out | 125 ++++++++++++++++-- .../apache/spark/sql/SQLQueryTestSuite.scala | 6 +- 4 files changed, 124 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 93307fc88356..c7aeef06a0bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -568,11 +568,12 @@ case class DescribeTableCommand( append(buffer, "# Detailed Table Information", "", "") append(buffer, "Database:", table.database, "") append(buffer, "Owner:", table.owner, "") - append(buffer, "Create Time:", new Date(table.createTime).toString, "") - append(buffer, "Last Access Time:", new Date(table.lastAccessTime).toString, "") + append(buffer, "Created:", new Date(table.createTime).toString, "") + append(buffer, "Last Access:", new Date(table.lastAccessTime).toString, "") append(buffer, "Location:", table.storage.locationUri.map(CatalogUtils.URIToString(_)) .getOrElse(""), "") append(buffer, "Table Type:", table.tableType.name, "") + append(buffer, "Comment:", table.comment.getOrElse(""), "") table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, "")) append(buffer, "Table Parameters:", "", "") diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index ff327f5e82b1..56f3281440d2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -1,4 +1,4 @@ -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d); +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) COMMENT 'table_comment'; ALTER TABLE t ADD PARTITION (c='Us', d=1); @@ -8,15 +8,15 @@ DESC t; DESC TABLE t; --- Ignore these because there exist timestamp results, e.g., `Create Table`. --- DESC EXTENDED t; --- DESC FORMATTED t; +DESC FORMATTED t; + +DESC EXTENDED t; DESC t PARTITION (c='Us', d=1); --- Ignore these because there exist timestamp results, e.g., transient_lastDdlTime. 
--- DESC EXTENDED t PARTITION (c='Us', d=1); --- DESC FORMATTED t PARTITION (c='Us', d=1); +DESC EXTENDED t PARTITION (c='Us', d=1); + +DESC FORMATTED t PARTITION (c='Us', d=1); -- NoSuchPartitionException: Partition not found in table DESC t PARTITION (c='Us', d=2); diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 0a11c1cde2b4..422d548ea8de 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,9 +1,9 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 14 -- !query 0 -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) COMMENT 'table_comment' -- !query 0 schema struct<> -- !query 0 output @@ -64,12 +64,25 @@ d string -- !query 5 -DESC t PARTITION (c='Us', d=1) +DESC FORMATTED t -- !query 5 schema struct -- !query 5 output +# Detailed Table Information # Partition Information +# Storage Information # col_name data_type comment +Comment: table_comment +Compressed: No +Created: +Database: default +Last Access: +Location: sql/core/spark-warehouse/t +Owner: +Partition Provider: Catalog +Storage Desc Parameters: +Table Parameters: +Table Type: MANAGED a string b int c string @@ -79,30 +92,114 @@ d string -- !query 6 -DESC t PARTITION (c='Us', d=2) +DESC EXTENDED t -- !query 6 schema -struct<> +struct -- !query 6 output +# Detailed Table Information CatalogTable( + Table: `default`.`t` + Created: + Last Access: + Type: MANAGED + Schema: [StructField(a,StringType,true), StructField(b,IntegerType,true), StructField(c,StringType,true), StructField(d,StringType,true)] + Provider: parquet + Partition Columns: [`c`, `d`] + Comment: table_comment + Storage(Location: sql/core/spark-warehouse/t) + Partition Provider: Catalog) +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 7 +DESC t PARTITION (c='Us', d=1) +-- !query 7 schema +struct +-- !query 7 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 8 +DESC EXTENDED t PARTITION (c='Us', d=1) +-- !query 8 schema +struct +-- !query 8 output +# Partition Information +# col_name data_type comment +Detailed Partition Information CatalogPartition( + Partition Values: [c=Us, d=1] + Storage(Location: sql/core/spark-warehouse/t/c=Us/d=1) + Partition Parameters:{}) +a string +b int +c string +c string +d string +d string + + +-- !query 9 +DESC FORMATTED t PARTITION (c='Us', d=1) +-- !query 9 schema +struct +-- !query 9 output +# Detailed Partition Information +# Partition Information +# Storage Information +# col_name data_type comment +Compressed: No +Database: default +Location: sql/core/spark-warehouse/t/c=Us/d=1 +Partition Parameters: +Partition Value: [Us, 1] +Storage Desc Parameters: +Table: t +a string +b int +c string +c string +d string +d string + + +-- !query 10 +DESC t PARTITION (c='Us', d=2) +-- !query 10 schema +struct<> +-- !query 10 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 7 +-- !query 11 DESC t PARTITION (c='Us') --- !query 7 schema +-- !query 11 schema struct<> --- !query 7 output +-- !query 11 output 
org.apache.spark.sql.AnalysisException Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 8 +-- !query 12 DESC t PARTITION (c='Us', d) --- !query 8 schema +-- !query 12 schema struct<> --- !query 8 output +-- !query 12 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -112,9 +209,9 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 9 +-- !query 13 DROP TABLE t --- !query 9 schema +-- !query 13 schema struct<> --- !query 9 output +-- !query 13 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index c285995514c8..4092862c430b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -223,9 +223,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val schema = df.schema // Get answer, but also get rid of the #1234 expression ids that show up in explain plans val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") - .replaceAll("Location: .*/sql/core/", "Location: sql/core/") - .replaceAll("Created: .*\n", "Created: \n") - .replaceAll("Last Access: .*\n", "Last Access: \n")) + .replaceAll("Location:.*/sql/core/", "Location: sql/core/") + .replaceAll("Created: .*", "Created: ") + .replaceAll("Last Access: .*", "Last Access: ")) // If the output is not pre-sorted, sort it. if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) From facfd608865c385c0dabfe09cffe5874532a9cdf Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 22 Mar 2017 11:10:08 +0000 Subject: [PATCH 101/512] [SPARK-20021][PYSPARK] Miss backslash in python code ## What changes were proposed in this pull request? Add backslash for line continuation in python code. ## How was this patch tested? Jenkins. Author: uncleGen Author: dylon Closes #17352 from uncleGen/python-example-doc. --- docs/structured-streaming-programming-guide.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 798847237866..ff07ad11943b 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -764,11 +764,11 @@ Dataset windowedCounts = words words = ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } # Group the data by window and word and compute the count of each group -windowedCounts = words - .withWatermark("timestamp", "10 minutes") +windowedCounts = words \ + .withWatermark("timestamp", "10 minutes") \ .groupBy( window(words.timestamp, "10 minutes", "5 minutes"), - words.word) + words.word) \ .count() {% endhighlight %} From 0caade634076034182e22318eb09a6df1c560576 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 22 Mar 2017 13:52:03 +0000 Subject: [PATCH 102/512] [SPARK-20027][DOCS] Compilation fix in java docs. ## What changes were proposed in this pull request? During build/sbt publish-local, build breaks due to javadocs errors. This patch fixes those errors. ## How was this patch tested? Tested by running the sbt build. Author: Prashant Sharma Closes #17358 from ScrapCodes/docs-fix. 
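For concreteness, a hedged illustration of the Javadoc style that trips the standalone javadoc run versus the prose form the hunks below switch to; the class and method names are hypothetical, not from the patch:

```java
// Hypothetical example: 'unsafe' uses a bare '->' arrow (doclint flags the '>') and an
// '@see' tag whose target is not a resolvable Java reference; 'clean' shows the
// plain-prose style this patch adopts.
public class JavadocStyles {
  /**
   * Converts the value from {@code long} -> {@code int}.
   *
   * @see README.md
   */
  public static int unsafe(long value) { return (int) value; }

  /**
   * Converts the value from {@code long} to {@code int}.
   *
   * Please see README.md for more details.
   */
  public static int clean(long value) { return (int) value; }
}
```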
--- .../java/org/apache/spark/network/crypto/ClientChallenge.java | 2 +- .../java/org/apache/spark/network/crypto/ServerResponse.java | 2 +- .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 2 +- .../api/java/function/FlatMapGroupsWithStateFunction.java | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java index 3312a5bd81a6..819b8a7efbdb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java @@ -28,7 +28,7 @@ /** * The client challenge message, used to initiate authentication. * - * @see README.md + * Please see crypto/README.md for more details of implementation. */ public class ClientChallenge implements Encodable { /** Serialization tag used to catch incorrect payloads. */ diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java index affdbf450b1d..caf3a0f3b38c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java @@ -28,7 +28,7 @@ /** * Server's response to client's challenge. * - * @see README.md + * Please see crypto/README.md for more details. */ public class ServerResponse implements Encodable { /** Serialization tag used to catch incorrect payloads. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 4c28075bd938..5437e998c085 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -863,7 +863,7 @@ public static class LongWrapper { * This is done solely for better performance and is not expected to be used by end users. * * {@link LongWrapper} could have been used here but using `int` directly save the extra cost of - * conversion from `long` -> `int` + * conversion from `long` to `int` */ public static class IntWrapper { public int value = 0; diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index 29af78c4f6a8..bdda8aaf734d 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -28,7 +28,8 @@ * ::Experimental:: * Base interface for a map function used in * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState( - * FlatMapGroupsWithStateFunction, org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)}. 
+ * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode, + * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} * @since 2.1.1 */ @Experimental From 465818389aab1217c9de5c685cfaee3ffaec91bb Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 22 Mar 2017 09:52:37 -0700 Subject: [PATCH 103/512] [SPARK-19949][SQL][FOLLOW-UP] Clean up parse modes and update related comments ## What changes were proposed in this pull request? This PR proposes to make `mode` options in both CSV and JSON to use `cass object` and fix some related comments related previous fix. Also, this PR modifies some tests related parse modes. ## How was this patch tested? Modified unit tests in both `CSVSuite.scala` and `JsonSuite.scala`. Author: hyukjinkwon Closes #17377 from HyukjinKwon/SPARK-19949. --- python/pyspark/sql/readwriter.py | 6 +- python/pyspark/sql/streaming.py | 2 + .../expressions/jsonExpressions.scala | 4 +- .../spark/sql/catalyst/json/JSONOptions.scala | 12 +- .../sql/catalyst/util/FailureSafeParser.scala | 15 ++- .../spark/sql/catalyst/util/ParseMode.scala | 56 +++++++++ .../spark/sql/catalyst/util/ParseModes.scala | 41 ------- .../expressions/JsonExpressionsSuite.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../datasources/csv/CSVOptions.scala | 14 +-- .../datasources/json/JsonInferSchema.scala | 3 +- .../sql/streaming/DataStreamReader.scala | 2 +- .../execution/datasources/csv/CSVSuite.scala | 7 +- .../datasources/json/JsonSuite.scala | 113 +++++++----------- 14 files changed, 130 insertions(+), 153 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 122e17f2020f..759c27507c39 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -369,10 +369,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param maxCharsPerColumn: defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, ``-1`` meaning unlimited length. - :param maxMalformedLogPerPartition: sets the maximum number of malformed rows Spark will - log for each partition. Malformed records beyond this - number will be ignored. If None is set, it - uses the default value, ``10``. + :param maxMalformedLogPerPartition: this parameter is no longer used since Spark 2.2.0. + If specified, it is ignored. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 288cc1e4f64d..e227f9ceb576 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -625,6 +625,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param maxCharsPerColumn: defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, ``-1`` meaning unlimited length. + :param maxMalformedLogPerPartition: this parameter is no longer used since Spark 2.2.0. + If specified, it is ignored. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. 
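A minimal PySpark sketch of the three case-insensitive modes this option documents; the sample records and local session are illustrative assumptions:

```python
# Illustrative only: two JSON records, the second one malformed.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("parse-modes").getOrCreate()
records = spark.sparkContext.parallelize(['{"a": 1}', '{"a": '])

# PERMISSIVE (default): the malformed record is kept and surfaced via _corrupt_record.
spark.read.option("mode", "permissive").json(records).show()
# DROPMALFORMED: the malformed record is silently dropped.
spark.read.option("mode", "DROPMALFORMED").json(records).show()
# FAILFAST: reading raises an exception on the first malformed record.
# spark.read.option("mode", "FAILFAST").json(records).show()
```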
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 08af5522d822..df4d406b84d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -548,7 +548,7 @@ case class JsonToStructs( lazy val parser = new JacksonParser( rowSchema, - new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) + new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) override def dataType: DataType = schema diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 355c26afa6f0..c22b1ade4e64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -65,7 +65,8 @@ private[sql] class JSONOptions( val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) - val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) @@ -82,15 +83,6 @@ private[sql] class JSONOptions( val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) - // Parse mode flags - if (!ParseModes.isValidMode(parseMode)) { - logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") - } - - val failFast = ParseModes.isFailFastMode(parseMode) - val dropMalformed = ParseModes.isDropMalformedMode(parseMode) - val permissive = ParseModes.isPermissiveMode(parseMode) - /** Sets config options on a Jackson [[JsonFactory]]. 
*/ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index e8da10d65ecb..725e3015b341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -24,7 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String class FailureSafeParser[IN]( rawParser: IN => Seq[InternalRow], - mode: String, + mode: ParseMode, schema: StructType, columnNameOfCorruptRecord: String) { @@ -58,11 +58,14 @@ class FailureSafeParser[IN]( try { rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) } catch { - case e: BadRecordException if ParseModes.isPermissiveMode(mode) => - Iterator(toResultRow(e.partialResult(), e.record)) - case _: BadRecordException if ParseModes.isDropMalformedMode(mode) => - Iterator.empty - case e: BadRecordException => throw e.cause + case e: BadRecordException => mode match { + case PermissiveMode => + Iterator(toResultRow(e.partialResult(), e.record)) + case DropMalformedMode => + Iterator.empty + case FailFastMode => + throw e.cause + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala new file mode 100644 index 000000000000..4565dbde88c8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.internal.Logging + +sealed trait ParseMode { + /** + * String name of the parse mode. + */ + def name: String +} + +/** + * This mode permissively parses the records. + */ +case object PermissiveMode extends ParseMode { val name = "PERMISSIVE" } + +/** + * This mode ignores the whole corrupted records. + */ +case object DropMalformedMode extends ParseMode { val name = "DROPMALFORMED" } + +/** + * This mode throws an exception when it meets corrupted records. + */ +case object FailFastMode extends ParseMode { val name = "FAILFAST" } + +object ParseMode extends Logging { + /** + * Returns the parse mode from the given string. + */ + def fromString(mode: String): ParseMode = mode.toUpperCase match { + case PermissiveMode.name => PermissiveMode + case DropMalformedMode.name => DropMalformedMode + case FailFastMode.name => FailFastMode + case _ => + logWarning(s"$mode is not a valid parse mode. 
Using ${PermissiveMode.name}.") + PermissiveMode + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala deleted file mode 100644 index 0e466962b467..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -object ParseModes { - val PERMISSIVE_MODE = "PERMISSIVE" - val DROP_MALFORMED_MODE = "DROPMALFORMED" - val FAIL_FAST_MODE = "FAILFAST" - - val DEFAULT = PERMISSIVE_MODE - - def isValidMode(mode: String): Boolean = { - mode.toUpperCase match { - case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true - case _ => false - } - } - - def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE - def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE - def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { - mode.toUpperCase == PERMISSIVE_MODE - } else { - true // We default to permissive is the mode string is not valid - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index e4698d44636b..c5b72235e5db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -21,7 +21,7 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -367,7 +367,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Other modes should still return `null`. 
checkEvaluation( - JsonToStructs(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), null ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 767a636d7073..e39b4d91f1f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -510,10 +510,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * a record can have. *
* <li>`maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed
* for any given value being read. By default, it is -1 meaning unlimited length</li>
- * <li>`maxMalformedLogPerPartition` (default `10`): sets the maximum number of malformed rows
- * Spark will log for each partition. Malformed records beyond this number will be ignored.</li>
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
- * during parsing.
+ * during parsing. It supports the following case-insensitive modes.
* <ul>
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index f6c6b6f56cd9..5d2c23ed9618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -82,7 +82,8 @@ class CSVOptions( val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) val charset = parameters.getOrElse("encoding", parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) @@ -95,15 +96,6 @@ class CSVOptions( val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") - // Parse mode flags - if (!ParseModes.isValidMode(parseMode)) { - logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") - } - - val failFast = ParseModes.isFailFastMode(parseMode) - val dropMalformed = ParseModes.isDropMalformedMode(parseMode) - val permissive = ParseModes.isPermissiveMode(parseMode) - val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) @@ -139,8 +131,6 @@ class CSVOptions( val escapeQuotes = getBool("escapeQuotes", true) - val maxMalformedLogPerPartition = getInt("maxMalformedLogPerPartition", 10) - val quoteAll = getBool("quoteAll", false) val inputBufferSize = 128 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index 7475f8ec7933..e15c30b4374b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.catalyst.util.PermissiveMode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -40,7 +41,7 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - val shouldHandleCorruptRecord = configOptions.permissive + val shouldHandleCorruptRecord = configOptions.parseMode == PermissiveMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord // perform schema inference on each row and merge afterwards diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 388ef182ce3a..f6e2fef74b8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -260,7 +260,7 @@ final class DataStreamReader 
private[sql](sparkSession: SparkSession) extends Lo *
* <li>`maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed
* for any given value being read. By default, it is -1 meaning unlimited length</li>
* <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
- * during parsing.
+ * during parsing. It supports the following case-insensitive modes.
* <ul>
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 598babfe0e7a..2600894ca303 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -992,9 +992,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { Seq(false, true).foreach { wholeFile => val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + // We use `PERMISSIVE` mode by default if invalid string is given. val df1 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "abcd") .option("wholeFile", wholeFile) .schema(schema) .csv(testFile(valueMalformedFile)) @@ -1008,7 +1009,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) val df2 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .option("wholeFile", wholeFile) .schema(schemaWithCorrField1) @@ -1025,7 +1026,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .add("b", TimestampType) val df3 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .option("wholeFile", wholeFile) .schema(schemaWithCorrField2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 56fcf773f7dd..b09cef76d2be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1083,83 +1083,59 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Corrupt records: PERMISSIVE mode, without designated column for malformed records") { - withTempView("jsonTable") { - val schema = StructType( - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) + val schema = StructType( + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) - val jsonDF = spark.read.schema(schema).json(corruptRecords) - jsonDF.createOrReplaceTempView("jsonTable") + val jsonDF = spark.read.schema(schema).json(corruptRecords) - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - """.stripMargin), - Seq( - // Corrupted records are replaced with null - Row(null, null, null), - Row(null, null, null), - Row(null, null, null), - Row("str_a_4", "str_b_4", "str_c_4"), - Row(null, null, null)) - ) - } + checkAnswer( + jsonDF.select($"a", $"b", $"c"), + Seq( + // Corrupted records are replaced with null + Row(null, null, null), + Row(null, null, null), + Row(null, null, null), + Row("str_a_4", "str_b_4", "str_c_4"), + Row(null, null, null)) + ) } test("Corrupt records: PERMISSIVE mode, with designated column 
for malformed records") { // Test if we can query corrupt records. withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { - withTempView("jsonTable") { - val jsonDF = spark.read.json(corruptRecords) - jsonDF.createOrReplaceTempView("jsonTable") - val schema = StructType( - StructField("_unparsed", StringType, true) :: + val jsonDF = spark.read.json(corruptRecords) + val schema = StructType( + StructField("_unparsed", StringType, true) :: StructField("a", StringType, true) :: StructField("b", StringType, true) :: StructField("c", StringType, true) :: Nil) - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. - checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - } + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + jsonDF.select($"a", $"b", $"c", $"_unparsed"), + Row(null, null, null, "{") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + jsonDF.filter($"_unparsed".isNull).select($"a", $"b", $"c"), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + jsonDF.filter($"_unparsed".isNotNull).select($"_unparsed"), + Row("{") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) } } @@ -1952,19 +1928,20 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("c", StringType, true) :: Nil) val errMsg = intercept[AnalysisException] { spark.read - .option("mode", "PERMISSIVE") + .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .schema(schema) .json(corruptRecords) }.getMessage assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + // We use `PERMISSIVE` mode by default if invalid string is given. withTempPath { dir => val path = dir.getCanonicalPath corruptRecords.toDF("value").write.text(path) val errMsg = intercept[AnalysisException] { spark.read - .option("mode", "PERMISSIVE") + .option("mode", "permm") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .schema(schema) .json(path) From 80fd070389a9c8ffa342d7b11f1ab2ea92e0f562 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 22 Mar 2017 09:58:46 -0700 Subject: [PATCH 104/512] [SPARK-20018][SQL] Pivot with timestamp and count should not print internal representation ## What changes were proposed in this pull request? 
Currently, when we perform count with timestamp types, it prints the internal representation as the column name as below: ```scala Seq(new java.sql.Timestamp(1)).toDF("a").groupBy("a").pivot("a").count().show() ``` ``` +--------------------+----+ | a|1000| +--------------------+----+ |1969-12-31 16:00:...| 1| +--------------------+----+ ``` This PR proposes to use external Scala value instead of the internal representation in the column names as below: ``` +--------------------+-----------------------+ | a|1969-12-31 16:00:00.001| +--------------------+-----------------------+ |1969-12-31 16:00:...| 1| +--------------------+-----------------------+ ``` ## How was this patch tested? Unit test in `DataFramePivotSuite` and manual tests. Author: hyukjinkwon Closes #17348 from HyukjinKwon/SPARK-20018. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 6 ++++-- .../apache/spark/sql/DataFramePivotSuite.scala | 18 +++++++++++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 574f91b09912..036ed060d9ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -486,14 +486,16 @@ class Analyzer( case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { + val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") if (singleAgg) { - value.toString + stringValue } else { val suffix = aggregate match { case n: NamedExpression => n.name case _ => toPrettySQL(aggregate) } - value + "_" + suffix + stringValue + "_" + suffix } } if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index ca3cb5676742..6ca9ee57e8f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class DataFramePivotSuite extends QueryTest with SharedSQLContext{ +class DataFramePivotSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("pivot courses") { @@ -230,4 +230,20 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ .groupBy($"a").pivot("a").agg(min($"b")), Row(null, Seq(null, 7), null) :: Row(1, null, Seq(1, 7)) :: Nil) } + + test("pivot with timestamp and count should not print internal representation") { + val ts = "2012-12-31 16:00:10.011" + val tsWithZone = "2013-01-01 00:00:10.011" + + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + val df = Seq(java.sql.Timestamp.valueOf(ts)).toDF("a").groupBy("a").pivot("a").count() + val expected = StructType( + StructField("a", TimestampType) :: + StructField(tsWithZone, LongType) :: Nil) + assert(df.schema == expected) + // String representation of timestamp with timezone should take the time difference + // into account. 
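+ // Illustration (not part of the original test): the pivoted column name is now the
+ // zone-adjusted external string rather than the internal microsecond count (the
+ // `1000` column shown in the commit message above). Since the session time zone is
+ // GMT while the Timestamp literal was parsed in the JVM-default zone, the column
+ // name shifts by eight hours here:
+ //   df.schema.fieldNames.toSeq == Seq("a", tsWithZone)  // "2013-01-01 00:00:10.011"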
+ checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone)) + } + } } From 82b598b963a21ae9d6a2a9638e86b4165c2a78c9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 22 Mar 2017 12:30:36 -0700 Subject: [PATCH 105/512] [SPARK-20057][SS] Renamed KeyedState to GroupState in mapGroupsWithState ## What changes were proposed in this pull request? Since the state is tied a "group" in the "mapGroupsWithState" operations, its better to call the state "GroupState" instead of a key. This would make it more general if you extends this operation to RelationGroupedDataset and python APIs. ## How was this patch tested? Existing unit tests. Author: Tathagata Das Closes #17385 from tdas/SPARK-20057. --- ...ateTimeout.java => GroupStateTimeout.java} | 18 +-- .../sql/catalyst/plans/logical/object.scala | 18 +-- ...e.java => JavaGroupStateTimeoutSuite.java} | 8 +- .../FlatMapGroupsWithStateFunction.java | 4 +- .../function/MapGroupsWithStateFunction.java | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 46 +++---- .../apache/spark/sql/execution/objects.scala | 8 +- .../FlatMapGroupsWithStateExec.scala | 16 +-- ...edStateImpl.scala => GroupStateImpl.scala} | 19 +-- .../streaming/statefulOperators.scala | 4 +- .../{KeyedState.scala => GroupState.scala} | 68 +++++----- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../FlatMapGroupsWithStateSuite.scala | 122 +++++++++--------- 13 files changed, 172 insertions(+), 167 deletions(-) rename sql/catalyst/src/main/java/org/apache/spark/sql/streaming/{KeyedStateTimeout.java => GroupStateTimeout.java} (79%) rename sql/catalyst/src/test/java/org/apache/spark/sql/streaming/{JavaKeyedStateTimeoutSuite.java => JavaGroupStateTimeoutSuite.java} (70%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{KeyedStateImpl.scala => GroupStateImpl.scala} (94%) rename sql/core/src/main/scala/org/apache/spark/sql/streaming/{KeyedState.scala => GroupState.scala} (84%) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java similarity index 79% rename from sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index e2e7ab1d2609..bd5e2d7ecca9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -24,31 +24,31 @@ /** * Represents the type of timeouts possible for the Dataset operations * `mapGroupsWithState` and `flatMapGroupsWithState`. See documentation on - * `KeyedState` for more details. + * `GroupState` for more details. * * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving -public class KeyedStateTimeout { +public class GroupStateTimeout { /** * Timeout based on processing time. The duration of timeout can be set for each group in - * `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutDuration()`. See documentation - * on `KeyedState` for more details. + * `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation + * on `GroupState` for more details. */ - public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } + public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } /** * Timeout based on event-time. 
The event-time timestamp for timeout can be set for each - * group in `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutTimestamp()`. + * group in `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutTimestamp()`. * In addition, you have to define the watermark in the query using `Dataset.withWatermark`. * When the watermark advances beyond the set timestamp of a group and the group has not * received any data, then the group times out. See documentation on - * `KeyedState` for more details. + * `GroupState` for more details. */ - public static KeyedStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; } + public static GroupStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; } /** No timeout. */ - public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } + public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index e0ecf8c5f264..6225b3fa4299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode } +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode } import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -351,22 +351,22 @@ case class MapGroups( child: LogicalPlan) extends UnaryNode with ObjectProducer /** Internal class representing State */ -trait LogicalKeyedState[S] +trait LogicalGroupState[S] /** Types of timeouts used in FlatMapGroupsWithState */ -case object NoTimeout extends KeyedStateTimeout -case object ProcessingTimeTimeout extends KeyedStateTimeout -case object EventTimeTimeout extends KeyedStateTimeout +case object NoTimeout extends GroupStateTimeout +case object ProcessingTimeTimeout extends GroupStateTimeout +case object EventTimeTimeout extends GroupStateTimeout /** Factory for constructing new `MapGroupsWithState` nodes. 
*/ object FlatMapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputMode: OutputMode, isMapGroupsWithState: Boolean, - timeout: KeyedStateTimeout, + timeout: GroupStateTimeout, child: LogicalPlan): LogicalPlan = { val encoder = encoderFor[S] @@ -404,7 +404,7 @@ object FlatMapGroupsWithState { * @param timeout used to timeout groups that have not received data in a while */ case class FlatMapGroupsWithState( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], @@ -413,7 +413,7 @@ case class FlatMapGroupsWithState( stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, isMapGroupsWithState: Boolean = false, - timeout: KeyedStateTimeout, + timeout: GroupStateTimeout, child: LogicalPlan) extends UnaryNode with ObjectProducer { if (isMapGroupsWithState) { diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java similarity index 70% rename from sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java rename to sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java index 02c94b0b3244..2e8f2e3fd9f4 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaKeyedStateTimeoutSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java @@ -17,13 +17,17 @@ package org.apache.spark.sql.streaming; +import org.apache.spark.sql.catalyst.plans.logical.EventTimeTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$; import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; import org.junit.Test; -public class JavaKeyedStateTimeoutSuite { +public class JavaGroupStateTimeoutSuite { @Test public void testTimeouts() { - assert(KeyedStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$); + assert (GroupStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$); + assert (GroupStateTimeout.EventTimeTimeout() == EventTimeTimeout$.MODULE$); + assert (GroupStateTimeout.NoTimeout() == NoTimeout$.MODULE$); } } diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index bdda8aaf734d..026b37cabbf1 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.streaming.KeyedState; +import org.apache.spark.sql.streaming.GroupState; /** * ::Experimental:: @@ -35,5 +35,5 @@ @Experimental @InterfaceStability.Evolving public interface FlatMapGroupsWithStateFunction extends Serializable { - Iterator call(K key, Iterator values, KeyedState state) throws Exception; + Iterator call(K key, Iterator values, GroupState state) throws Exception; 
} diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index 70f3f01a8e9d..353e9886a8a5 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.streaming.KeyedState; +import org.apache.spark.sql.streaming.GroupState; /** * ::Experimental:: @@ -34,5 +34,5 @@ @Experimental @InterfaceStability.Evolving public interface MapGroupsWithStateFunction extends Serializable { - R call(K key, Iterator values, KeyedState state) throws Exception; + R call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 96437f868a6e..87c562176887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator -import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} /** * :: Experimental :: @@ -228,7 +228,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[org.apache.spark.sql.streaming.KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -240,17 +240,17 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { - val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)) + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) Dataset[U]( sparkSession, FlatMapGroupsWithState[K, V, S, U]( - flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, OutputMode.Update, isMapGroupsWithState = true, - KeyedStateTimeout.NoTimeout, + GroupStateTimeout.NoTimeout, child = logicalPlan)) } @@ -262,7 +262,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. 
For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[org.apache.spark.sql.streaming.KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -275,13 +275,13 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( - timeoutConf: KeyedStateTimeout)( - func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { - val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)) + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) Dataset[U]( sparkSession, FlatMapGroupsWithState[K, V, S, U]( - flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, OutputMode.Update, @@ -298,7 +298,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -316,7 +316,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U] = { mapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) + (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } @@ -328,7 +328,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -346,9 +346,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: KeyedStateTimeout): Dataset[U] = { + timeoutConf: GroupStateTimeout): Dataset[U] = { mapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s) + (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } @@ -360,7 +360,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. 
For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -375,15 +375,15 @@ class KeyValueGroupedDataset[K, V] private[sql]( @InterfaceStability.Evolving def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, - timeoutConf: KeyedStateTimeout)( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { throw new IllegalArgumentException("The output mode of function should be append or update") } Dataset[U]( sparkSession, FlatMapGroupsWithState[K, V, S, U]( - func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, outputMode, @@ -400,7 +400,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. 
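A rough usage sketch of the renamed API (illustrative only: `Event`, `SessionInfo`, and the input `events` Dataset are assumptions and not part of this patch; only `GroupState` and `GroupStateTimeout` come from it):
```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

case class Event(user: String, action: String)
case class SessionInfo(numEvents: Long)

// events: Dataset[Event]; spark.implicits._ is assumed in scope for the encoders.
val counts = events
  .groupByKey(_.user)
  .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.ProcessingTimeTimeout) {
    (user: String, values: Iterator[Event], state: GroupState[SessionInfo]) =>
      val updated = SessionInfo(state.getOption.map(_.numEvents).getOrElse(0L) + values.size)
      state.update(updated)               // state value must never be null
      state.setTimeoutDuration("1 hour")  // must be set again on every invocation
      Iterator((user, updated.numEvents))
  }
```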
@@ -420,8 +420,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: KeyedStateTimeout): Dataset[U] = { - val f = (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala + timeoutConf: GroupStateTimeout): Dataset[U] = { + val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index fdd1bcc94be2..48c7b80bffe0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState -import org.apache.spark.sql.execution.streaming.KeyedStateImpl +import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState +import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -355,14 +355,14 @@ case class MapGroupsExec( object MapGroupsExec { def apply( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None)) + val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None)) new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 52ad70c7dc88..c7262ea97200 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.KeyedStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator @@ -44,7 +44,7 @@ import org.apache.spark.util.CompletionIterator * @param batchTimestampMs processing timestamp of the current batch. 
*/ case class FlatMapGroupsWithStateExec( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], @@ -53,13 +53,13 @@ case class FlatMapGroupsWithStateExec( stateId: Option[OperatorStateId], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, - timeoutConf: KeyedStateTimeout, + timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], override val eventTimeWatermark: Option[Long], child: SparkPlan ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { - import KeyedStateImpl._ + import GroupStateImpl._ private val isTimeoutEnabled = timeoutConf != NoTimeout private val timestampTimeoutAttribute = @@ -147,7 +147,7 @@ case class FlatMapGroupsWithStateExec( private val stateSerializer = { val encoderSerializer = stateEncoder.namedExpressions if (isTimeoutEnabled) { - encoderSerializer :+ Literal(KeyedStateImpl.NO_TIMESTAMP) + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) } else { encoderSerializer } @@ -211,7 +211,7 @@ case class FlatMapGroupsWithStateExec( val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects val stateObjOption = getStateObj(prevStateRowOption) - val keyedState = new KeyedStateImpl( + val keyedState = new GroupStateImpl( stateObjOption, batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), @@ -247,7 +247,7 @@ case class FlatMapGroupsWithStateExec( if (shouldWriteState) { if (stateRowToWrite == null) { - // This should never happen because checks in KeyedStateImpl should avoid cases + // This should never happen because checks in GroupStateImpl should avoid cases // where empty state would need to be written throw new IllegalStateException("Attempting to write empty state") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index edfd35bd5dd7..148d92247d6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -22,13 +22,14 @@ import java.sql.Date import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} -import org.apache.spark.sql.execution.streaming.KeyedStateImpl._ -import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout} +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} import org.apache.spark.unsafe.types.CalendarInterval /** - * Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. + * Internal implementation of the [[GroupState]] interface. Methods are not thread-safe. 
+ * * @param optionalValue Optional value of the state * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp * for processing time timeouts @@ -37,19 +38,19 @@ import org.apache.spark.unsafe.types.CalendarInterval * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ -private[sql] class KeyedStateImpl[S]( +private[sql] class GroupStateImpl[S]( optionalValue: Option[S], batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, - timeoutConf: KeyedStateTimeout, - override val hasTimedOut: Boolean) extends KeyedState[S] { + timeoutConf: GroupStateTimeout, + override val hasTimedOut: Boolean) extends GroupState[S] { // Constructor to create dummy state when using mapGroupsWithState in a batch query def this(optionalValue: Option[S]) = this( optionalValue, batchProcessingTimeMs = NO_TIMESTAMP, eventTimeWatermarkMs = NO_TIMESTAMP, - timeoutConf = KeyedStateTimeout.NoTimeout, + timeoutConf = GroupStateTimeout.NoTimeout, hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined @@ -169,7 +170,7 @@ private[sql] class KeyedStateImpl[S]( } override def toString: String = { - s"KeyedState(${getOption.map(_.toString).getOrElse("")})" + s"GroupState(${getOption.map(_.toString).getOrElse("")})" } // ========= Internal API ========= @@ -221,7 +222,7 @@ private[sql] class KeyedStateImpl[S]( } -private[sql] object KeyedStateImpl { +private[sql] object GroupStateImpl { // Value used represent the lack of valid timestamp as a long val NO_TIMESTAMP = -1L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index f72144a25d5c..8dbda298c87b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types._ import org.apache.spark.util.CompletionIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 461de04f6bbe..60a4d0d8f98a 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.streaming import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{Encoder, KeyValueGroupedDataset} -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState +import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState /** * :: Experimental :: * - * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on - * [[KeyValueGroupedDataset]]. + * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on [[KeyValueGroupedDataset]]. * * Detail description on `[map/flatMap]GroupsWithState` operation * -------------------------------------------------------------- @@ -37,11 +36,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * Dataset, the function will be invoked for each group repeatedly in every trigger. * That is, in every batch of the `streaming.StreamingQuery`, * the function will be invoked once for each group that has data in the trigger. Furthermore, - * if timeout is set, then the function will invoked on timed out keys (more detail below). + * if timeout is set, then the function will invoked on timed out groups (more detail below). * * The function is invoked with following parameters. * - The key of the group. - * - An iterator containing all the values for this key. + * - An iterator containing all the values for this group. * - A user-defined state object set by previous invocations of the given function. * In case of a batch Dataset, there is only one invocation and state object will be empty as * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` @@ -55,57 +54,58 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * batch, nor with streaming Datasets. * - All the data will be shuffled before applying the function. * - If timeout is set, then the function will also be called with no values. - * See more details on `KeyedStateTimeout` below. + * See more details on `GroupStateTimeout` below. * - * Important points to note about using `KeyedState`. + * Important points to note about using `GroupState`. * - The value of the state cannot be null. So updating state with null will throw * `IllegalArgumentException`. - * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. + * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers. * - If `remove()` is called, then `exists()` will return `false`, * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` * - After that, if `update(newState)` is called, then `exists()` will again return `true`, * `get()` and `getOption()`will return the updated value. * - * Important points to note about using `KeyedStateTimeout`. - * - The timeout type is a global param across all the keys (set as `timeout` param in + * Important points to note about using `GroupStateTimeout`. + * - The timeout type is a global param across all the groups (set as `timeout` param in * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per - * key by calling `setTimeout...()` in `KeyedState`. + * group by calling `setTimeout...()` in `GroupState`. 
* - Timeouts can be either based on processing time (i.e. - * [[KeyedStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. - * [[KeyedStateTimeout.EventTimeTimeout]]). + * [[GroupStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. + * [[GroupStateTimeout.EventTimeTimeout]]). * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling - * `KeyedState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set + * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set * duration. Guarantees provided by this timeout with a duration of D ms are as follows: * - Timeout will never be occur before the clock time has advanced by D ms * - Timeout will occur eventually when there is a trigger in the query * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. * For example, the trigger interval of the query will affect when the timeout actually occurs. - * If there is no data in the stream (for any key) for a while, then their will not be + * If there is no data in the stream (for any group) for a while, then their will not be * any trigger and timeout function call will not occur until there is data. * - Since the processing time timeout is based on the clock time, it is affected by the * variations in the system clock (i.e. time zone changes, clock skew, etc.). * - With `EventTimeTimeout`, the user also has to specify the the the event time watermark in * the query using `Dataset.withWatermark()`. With this setting, data that is older than the - * watermark are filtered out. The timeout can be enabled for a key by setting a timestamp using - * `KeyedState.setTimeoutTimestamp()`, and the timeout would occur when the watermark advances - * beyond the set timestamp. You can control the timeout delay by two parameters - (i) watermark - * delay and an additional duration beyond the timestamp in the event (which is guaranteed to - * > watermark due to the filtering). Guarantees provided by this timeout are as follows: + * watermark are filtered out. The timeout can be set for a group by setting a timeout timestamp + * using`GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark + * advances beyond the set timestamp. You can control the timeout delay by two parameters - + * (i) watermark delay and an additional duration beyond the timestamp in the event (which + * is guaranteed to be newer than watermark due to the filtering). Guarantees provided by this + * timeout are as follows: * - Timeout will never be occur before watermark has exceeded the set timeout. * - Similar to processing time timeouts, there is a no strict upper bound on the delay when * the timeout actually occurs. The watermark can advance only when there is data in the * stream, and the event time of the data has actually advanced. - * - When the timeout occurs for a key, the function is called for that key with no values, and - * `KeyedState.hasTimedOut()` set to true. - * - The timeout is reset for key every time the function is called on the key, that is, - * when the key has new data, or the key has timed out. So the user has to set the timeout + * - When the timeout occurs for a group, the function is called for that group with no values, and + * `GroupState.hasTimedOut()` set to true. + * - The timeout is reset every time the function is called on a group, that is, + * when the group has new data, or the group has timed out. 
So the user has to set the timeout * duration every time the function is called, otherwise there will not be any timeout set. * - * Scala example of using KeyedState in `mapGroupsWithState`: + * Scala example of using GroupState in `mapGroupsWithState`: * {{{ * // A mapping function that maintains an integer state for string keys and returns a string. * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. - * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { + * def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = { * * if (state.hasTimedOut) { // If called when timing out, remove the state * state.remove() @@ -133,10 +133,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * * dataset * .groupByKey(...) - * .mapGroupsWithState(KeyedStateTimeout.ProcessingTimeTimeout)(mappingFunction) + * .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction) * }}} * - * Java example of using `KeyedState`: + * Java example of using `GroupState`: * {{{ * // A mapping function that maintains an integer state for string keys and returns a string. * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. @@ -144,7 +144,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * new MapGroupsWithStateFunction() { * * @Override - * public String call(String key, Iterator value, KeyedState state) { + * public String call(String key, Iterator value, GroupState state) { * if (state.hasTimedOut()) { // If called when timing out, remove the state * state.remove(); * @@ -173,16 +173,16 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * dataset * .groupByKey(...) * .mapGroupsWithState( - * mappingFunction, Encoders.INT, Encoders.STRING, KeyedStateTimeout.ProcessingTimeTimeout); + * mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout); * }}} * - * @tparam S User-defined type of the state to be stored for each key. Must be encodable into + * @tparam S User-defined type of the state to be stored for each group. Must be encodable into * Spark SQL types (see [[Encoder]] for more details). * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving -trait KeyedState[S] extends LogicalKeyedState[S] { +trait GroupState[S] extends LogicalGroupState[S] { /** Whether state exists or not. */ def exists: Boolean @@ -201,7 +201,7 @@ trait KeyedState[S] extends LogicalKeyedState[S] { @throws[IllegalArgumentException]("when updating with null") def update(newState: S): Unit - /** Remove this keyed state. Note that this resets any timeout configuration as well. */ + /** Remove this state. Note that this resets any timeout configuration as well. 
*/ def remove(): Unit /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ffb4c6273ff8..78cf033dd81d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,7 +23,7 @@ import java.sql.Timestamp; import java.util.*; -import org.apache.spark.sql.streaming.KeyedStateTimeout; +import org.apache.spark.sql.streaming.GroupStateTimeout; import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; import scala.Tuple3; @@ -210,7 +210,7 @@ public void testGroupBy() { OutputMode.Append(), Encoders.LONG(), Encoders.STRING(), - KeyedStateTimeout.NoTimeout()); + GroupStateTimeout.NoTimeout()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index fe72283bb608..3dabef6a9a35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec -import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, KeyedStateImpl, MemoryStream} +import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.types.{DataType, IntegerType} @@ -43,16 +43,16 @@ case class Result(key: Long, count: Int) class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ - import KeyedStateImpl._ - import KeyedStateTimeout._ + import GroupStateImpl._ + import GroupStateTimeout._ override def afterAll(): Unit = { super.afterAll() StateStore.stop() } - test("KeyedState - get, exists, update, remove") { - var state: KeyedStateImpl[String] = null + test("GroupState - get, exists, update, remove") { + var state: GroupStateImpl[String] = null def testState( expectedData: Option[String], @@ -73,13 +73,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } // Updating empty state - state = new KeyedStateImpl[String](None) + state = new GroupStateImpl[String](None) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = new KeyedStateImpl[String](Some("2")) + state = new GroupStateImpl[String](Some("2")) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -97,19 +97,19 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - test("KeyedState - setTimeout**** with NoTimeout") { + test("GroupState - setTimeout**** with NoTimeout") { for (initState <- Seq(None, Some(5))) { // for different initial state - implicit val state = new KeyedStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) + implicit val state = 
new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } } - test("KeyedState - setTimeout**** with ProcessingTimeTimeout") { - implicit var state: KeyedStateImpl[Int] = null + test("GroupState - setTimeout**** with ProcessingTimeTimeout") { + implicit var state: GroupStateImpl[Int] = null - state = new KeyedStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[IllegalStateException](state) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) @@ -128,8 +128,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } - test("KeyedState - setTimeout**** with EventTimeTimeout") { - implicit val state = new KeyedStateImpl[Int]( + test("GroupState - setTimeout**** with EventTimeTimeout") { + implicit val state = new GroupStateImpl[Int]( None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) @@ -148,8 +148,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[IllegalStateException](state) } - test("KeyedState - illegal params to setTimeout****") { - var state: KeyedStateImpl[Int] = null + test("GroupState - illegal params to setTimeout****") { + var state: GroupStateImpl[Int] = null // Test setTimeout****() with illegal values def testIllegalTimeout(body: => Unit): Unit = { @@ -157,14 +157,14 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new KeyedStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } testIllegalTimeout { state.setTimeoutDuration(0) } testIllegalTimeout { state.setTimeoutDuration("-2 second") } testIllegalTimeout { state.setTimeoutDuration("-1 month") } testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } - state = new KeyedStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } @@ -175,25 +175,25 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } } - test("KeyedState - hasTimedOut") { + test("GroupState - hasTimedOut") { for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { for (initState <- Seq(None, Some(5))) { - val state1 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) assert(state1.hasTimedOut === false) - val state2 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + val state2 
= new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) assert(state2.hasTimedOut === true) } } } - test("KeyedState - primitive type") { - var intState = new KeyedStateImpl[Int](None) + test("GroupState - primitive type") { + var intState = new GroupStateImpl[Int](None) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) - intState = new KeyedStateImpl[Int](Some(10)) + intState = new GroupStateImpl[Int](Some(10)) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -218,21 +218,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( testName + "no update", stateUpdates = state => { /* do nothing */ }, - timeoutConf = KeyedStateTimeout.NoTimeout, + timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change testStateUpdateWithData( testName + "state updated", stateUpdates = state => { state.update(5) }, - timeoutConf = KeyedStateTimeout.NoTimeout, + timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = Some(5)) // should change testStateUpdateWithData( testName + "state removed", stateUpdates = state => { state.remove() }, - timeoutConf = KeyedStateTimeout.NoTimeout, + timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = None) // should be removed } @@ -283,7 +283,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = - (state: KeyedState[Int]) => { state.update(5); state.setTimeoutDuration(5000) }, + (state: GroupState[Int]) => { state.update(5); state.setTimeoutDuration(5000) }, timeoutConf = ProcessingTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -293,7 +293,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: KeyedState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -303,7 +303,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( s"EventTimeTimeout - $testName - timeout timestamp updated to before watermark", stateUpdates = - (state: KeyedState[Int]) => { + (state: GroupState[Int]) => { state.update(5) intercept[IllegalArgumentException] { state.setTimeoutTimestamp(currentBatchWatermark - 1) // try to set to < watermark @@ -387,7 +387,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("StateStoreUpdater - rows are cloned before writing to StateStore") { // function for running count - val func = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { state.update(state.getOption.getOrElse(0) + values.size) Iterator.empty } @@ -404,7 +404,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return 
anything - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -420,7 +420,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a"), @@ -446,7 +446,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything // Additionally, it updates state lazily as the returned iterator get consumed - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { values.flatMap { _ => val count = state.getOption.map(_.count).getOrElse(0L) + 1 if (count == 3) { @@ -463,7 +463,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), @@ -481,7 +481,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("flatMapGroupsWithState - streaming + aggregation") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -497,7 +497,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Append, KeyedStateTimeout.NoTimeout)(stateFunc) + .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc) .groupByKey(_._1) .count() @@ -524,20 +524,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") Iterator((key, values.size)) } val df = Seq("a", "a", "b").toDS .groupByKey(x => x) - .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc).toDF + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc).toDF checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) } test("flatMapGroupsWithState - streaming with processing time timeout") { // Function to maintain running count up to 2, and then remove the 
count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { if (state.hasTimedOut) { state.remove() Iterator((key, "-1")) @@ -594,7 +594,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val stateFunc = ( key: String, values: Iterator[(String, Long)], - state: KeyedState[Long]) => { + state: GroupState[Long]) => { val timeoutDelay = 5 if (key != "a") { Iterator.empty @@ -637,7 +637,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -676,7 +676,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") (key, values.size) } @@ -690,7 +690,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } testQuietly("StateStore.abort on task failure handling") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") val count = state.getOption.map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) @@ -724,7 +724,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("output partitioning is unknown") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => key val inputData = MemoryStream[String] val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) testStream(result, Update)( @@ -735,13 +735,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("disallow complete mode") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[Int]) => { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { Iterator[String]() } var e = intercept[IllegalArgumentException] { MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( - OutputMode.Complete, KeyedStateTimeout.NoTimeout)(stateFunc) + OutputMode.Complete, GroupStateTimeout.NoTimeout)(stateFunc) } assert(e.getMessage === "The output mode of function should be append or update") @@ -750,20 +750,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf override def call( key: String, values: JIterator[String], - state: KeyedState[Int]): JIterator[String] = { 
null } + state: GroupState[Int]): JIterator[String] = { null } } e = intercept[IllegalArgumentException] { MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( javaStateFunc, OutputMode.Complete, - implicitly[Encoder[Int]], implicitly[Encoder[String]], KeyedStateTimeout.NoTimeout) + implicitly[Encoder[Int]], implicitly[Encoder[String]], GroupStateTimeout.NoTimeout) } assert(e.getMessage === "The output mode of function should be append or update") } def testStateUpdateWithData( testName: String, - stateUpdates: KeyedState[Int] => Unit, - timeoutConf: KeyedStateTimeout, + stateUpdates: GroupState[Int] => Unit, + timeoutConf: GroupStateTimeout, priorState: Option[Int], priorTimeoutTimestamp: Long = NO_TIMESTAMP, expectedState: Option[Int] = None, @@ -773,7 +773,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf return // there can be no prior timestamp, when there is no prior state } test(s"StateStoreUpdater - updates with data - $testName") { - val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") stateUpdates(state) @@ -787,14 +787,14 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdateWithTimeout( testName: String, - stateUpdates: KeyedState[Int] => Unit, - timeoutConf: KeyedStateTimeout, + stateUpdates: GroupState[Int] => Unit, + timeoutConf: GroupStateTimeout, priorTimeoutTimestamp: Long, expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { test(s"StateStoreUpdater - updates for timeout - $testName") { - val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") stateUpdates(state) @@ -808,8 +808,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdate( testTimeoutUpdates: Boolean, - mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], - timeoutConf: KeyedStateTimeout, + mapGroupsFunc: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], + timeoutConf: GroupStateTimeout, priorState: Option[Int], priorTimeoutTimestamp: Long, expectedState: Option[Int], @@ -848,8 +848,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } def newFlatMapGroupsWithStateExec( - func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], - timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], + timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { MemoryStream[Int] .toDS @@ -863,7 +863,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf }.get } - def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = { val prevTimestamp = state.getTimeoutTimestamp intercept[T] { state.setTimeoutDuration(1000) } assert(state.getTimeoutTimestamp === prevTimestamp) @@ -871,7 +871,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest 
with BeforeAndAf assert(state.getTimeoutTimestamp === prevTimestamp) } - def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = { val prevTimestamp = state.getTimeoutTimestamp intercept[T] { state.setTimeoutTimestamp(2000) } assert(state.getTimeoutTimestamp === prevTimestamp) From 12cd00706cbfff4c8ac681fcae65b4c4c8751877 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 22 Mar 2017 15:58:42 -0700 Subject: [PATCH 106/512] [BUILD][MINOR] Fix 2.10 build ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/17385 breaks the 2.10 sbt/maven builds by hitting an empty-string interpolation bug (https://issues.scala-lang.org/browse/SI-7919). https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-compile-sbt-scala-2.10/4072/ https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-compile-maven-scala-2.10/3987/ ## How was this patch tested? Compiles Author: Sameer Agarwal Closes #17391 from sameeragarwal/build-fix. --- .../spark/sql/streaming/FlatMapGroupsWithStateSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 3dabef6a9a35..89a25973afdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -240,7 +240,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - var testName = s"" + var testName = "" if (priorState.nonEmpty) { testName += "prior state set, " if (priorTimeoutTimestamp == 1000) { From 07c12c09a75645f6b56b30654455b3838b7b6637 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 23 Mar 2017 00:25:01 -0700 Subject: [PATCH 107/512] [SPARK-18579][SQL] Use ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options in CSV writing ## What changes were proposed in this pull request? This PR proposes to support _not_ trimming the white spaces when writing out. These are `false` by default in CSV reading path but these are `true` by default in CSV writing in univocity parser. Both `ignoreLeadingWhiteSpace` and `ignoreTrailingWhiteSpace` options are not being used for writing and therefore, we are always trimming the white spaces. It seems we should provide a way to keep this white spaces easily. WIth the data below: ```scala val df = spark.read.csv(Seq("a , b , c").toDS) df.show() ``` ``` +---+----+---+ |_c0| _c1|_c2| +---+----+---+ | a | b | c| +---+----+---+ ``` **Before** ```scala df.write.csv("/tmp/text.csv") spark.read.text("/tmp/text.csv").show() ``` ``` +-----+ |value| +-----+ |a,b,c| +-----+ ``` It seems this can't be worked around via `quoteAll` too. 
```scala df.write.option("quoteAll", true).csv("/tmp/text.csv") spark.read.text("/tmp/text.csv").show() ``` ``` +-----------+ | value| +-----------+ |"a","b","c"| +-----------+ ``` **After** ```scala df.write.option("ignoreLeadingWhiteSpace", false).option("ignoreTrailingWhiteSpace", false).csv("/tmp/text.csv") spark.read.text("/tmp/text.csv").show() ``` ``` +----------+ | value| +----------+ |a , b , c| +----------+ ``` Note that this case is possible in R ```r > system("cat text.csv") f1,f2,f3 a , b , c > df <- read.csv(file="text.csv") > df f1 f2 f3 1 a b c > write.csv(df, file="text1.csv", quote=F, row.names=F) > system("cat text1.csv") f1,f2,f3 a , b , c ``` ## How was this patch tested? Unit tests in `CSVSuite` and manual tests for Python. Author: hyukjinkwon Closes #17310 from HyukjinKwon/SPARK-18579. --- python/pyspark/sql/readwriter.py | 28 +++++---- python/pyspark/sql/streaming.py | 12 ++-- python/pyspark/sql/tests.py | 13 +++++ .../apache/spark/sql/DataFrameReader.scala | 6 +- .../apache/spark/sql/DataFrameWriter.scala | 6 +- .../datasources/csv/CSVOptions.scala | 15 +++-- .../sql/streaming/DataStreamReader.scala | 6 +- .../execution/datasources/csv/CSVSuite.scala | 57 +++++++++++++++++++ 8 files changed, 116 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 759c27507c39..5e732b4bec8f 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -341,12 +341,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. - :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. - :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. + :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. + :param ignoreTrailingWhiteSpace: A flag indicating whether or not trailing whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this ``nullValue`` param applies to all supported types including the string type. @@ -706,7 +706,7 @@ def text(self, path, compression=None): @since(2.0) def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, - timestampFormat=None): + timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -728,10 +728,10 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No empty string. :param escape: sets the single character used for escaping quotes inside an already quoted value. 
If None is set, it uses the default value, ``\`` - :param escapeQuotes: A flag indicating whether values containing quotes should always + :param escapeQuotes: a flag indicating whether values containing quotes should always be enclosed in quotes. If None is set, it uses the default value ``true``, escaping all values containing a quote character. - :param quoteAll: A flag indicating whether all values should always be enclosed in + :param quoteAll: a flag indicating whether all values should always be enclosed in quotes. If None is set, it uses the default value ``false``, only escaping values containing a quote character. :param header: writes the names of columns as the first line. If None is set, it uses @@ -746,13 +746,21 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from + values being written should be skipped. If None is set, it + uses the default value, ``true``. + :param ignoreTrailingWhiteSpace: a flag indicating whether or not trailing whitespaces from + values being written should be skipped. If None is set, it + uses the default value, ``true``. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, - dateFormat=dateFormat, timestampFormat=timestampFormat) + dateFormat=dateFormat, timestampFormat=timestampFormat, + ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, + ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e227f9ceb576..80f4340cdf13 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -597,12 +597,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. - :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. - :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. + :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. + :param ignoreTrailingWhiteSpace: a flag indicating whether or not trailing whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this ``nullValue`` param applies to all supported types including the string type. 
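As an illustrative aside (not part of the diff itself), here is a minimal PySpark sketch of the new writer options; the local-mode session and the scratch directory are assumptions made only for this example:

```python
import shutil
import tempfile

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("csv-whitespace-demo").getOrCreate()

# Scratch path for the example; csv() expects the target directory not to exist yet.
path = tempfile.mkdtemp()
shutil.rmtree(path)

df = spark.createDataFrame([(" a", "b ", " c ")])

# Disable the write-side trimming (both options default to true when writing).
df.write.csv(path, ignoreLeadingWhiteSpace=False, ignoreTrailingWhiteSpace=False)

# Reading the files back as plain text should show ' a,b , c ' rather than 'a,b,c'.
spark.read.text(path).show(truncate=False)
```

With both options set to `False`, the surrounding spaces survive the round trip, which is what the new `test_ignorewhitespace_csv` test below asserts.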
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f0a9a0400e39..29d613bc5fe3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -450,6 +450,19 @@ def test_wholefile_csv(self): Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] self.assertEqual(ages_newlines.collect(), expected) + def test_ignorewhitespace_csv(self): + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv( + tmpPath, + ignoreLeadingWhiteSpace=False, + ignoreTrailingWhiteSpace=False) + + expected = [Row(value=u' a,b , c ')] + readback = self.spark.read.text(tmpPath) + self.assertEqual(readback.collect(), expected) + shutil.rmtree(tmpPath) + def test_read_multiple_orc_file(self): df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0", "python/test_support/sql/orc_partitioned/b=1/c=1"]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e39b4d91f1f6..e6d2b1bc28d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -489,9 +489,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
<li>`header` (default `false`): uses the first line as names of columns.</li>
   * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
   * requires one extra pass over the data.</li>
-   * <li>`ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces
-   * from values being read should be skipped.</li>
-   * <li>`ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing
+   * <li>`ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading
+   * whitespaces from values being read should be skipped.</li>
+   * <li>`ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing
   * whitespaces from values being read should be skipped.</li>
   * <li>`nullValue` (default empty string): sets the string representation of a null value. Since
   * 2.0.1, this applies to all supported types including the string type.</li>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3e975ef6a3c2..e973d0bc6d09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -573,7 +573,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   * <li>`escapeQuotes` (default `true`): a flag indicating whether values containing
   * quotes should always be enclosed in quotes. Default is to escape all values containing
   * a quote character.</li>
-   * <li>`quoteAll` (default `false`): A flag indicating whether all values should always be
+   * <li>`quoteAll` (default `false`): a flag indicating whether all values should always be
   * enclosed in quotes. Default is to only escape values containing a quote character.</li>
   * <li>`header` (default `false`): writes the names of columns as the first line.</li>
   * <li>`nullValue` (default empty string): sets the string representation of a null value.</li>
@@ -586,6 +586,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
   * indicates a timestamp format. Custom date formats follow the formats at
   * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
+   * <li>`ignoreLeadingWhiteSpace` (default `true`): a flag indicating whether or not leading
+   * whitespaces from values being written should be skipped.</li>
+   * <li>`ignoreTrailingWhiteSpace` (default `true`): a flag indicating defines whether or not
+   * trailing whitespaces from values being written should be skipped.</li>
   * </ul>
      * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 5d2c23ed9618..e7b79e0cbfd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -93,8 +93,13 @@ class CSVOptions( val headerFlag = getBool("header") val inferSchemaFlag = getBool("inferSchema") - val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") - val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") + val ignoreLeadingWhiteSpaceInRead = getBool("ignoreLeadingWhiteSpace", default = false) + val ignoreTrailingWhiteSpaceInRead = getBool("ignoreTrailingWhiteSpace", default = false) + + // For write, both options were `true` by default. We leave it as `true` for + // backwards compatibility. + val ignoreLeadingWhiteSpaceFlagInWrite = getBool("ignoreLeadingWhiteSpace", default = true) + val ignoreTrailingWhiteSpaceFlagInWrite = getBool("ignoreTrailingWhiteSpace", default = true) val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) @@ -144,6 +149,8 @@ class CSVOptions( format.setQuote(quote) format.setQuoteEscape(escape) format.setComment(comment) + writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) + writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) writerSettings.setEmptyValue(nullValue) writerSettings.setSkipEmptyLines(true) @@ -159,8 +166,8 @@ class CSVOptions( format.setQuote(quote) format.setQuoteEscape(escape) format.setComment(comment) - settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlag) - settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlag) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) + settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) settings.setReadInputOnSeparateThread(false) settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f6e2fef74b8d..997ca286597d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -238,9 +238,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
<li>`header` (default `false`): uses the first line as names of columns.</li>
   * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
   * requires one extra pass over the data.</li>
-   * <li>`ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces
-   * from values being read should be skipped.</li>
-   * <li>`ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing
+   * <li>`ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading
+   * whitespaces from values being read should be skipped.</li>
+   * <li>`ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing
   * whitespaces from values being read should be skipped.</li>
   * <li>`nullValue` (default empty string): sets the string representation of a null value. Since
   * 2.0.1, this applies to all supported types including the string type.</li>
    • diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2600894ca303..d70c47f4e237 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1117,4 +1117,61 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(df2.schema === schema) } + test("ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options - read") { + val input = " a,b , c " + + // For reading, default of both `ignoreLeadingWhiteSpace` and`ignoreTrailingWhiteSpace` + // are `false`. So, these are excluded. + val combinations = Seq( + (true, true), + (false, true), + (true, false)) + + // Check if read rows ignore whitespaces as configured. + val expectedRows = Seq( + Row("a", "b", "c"), + Row(" a", "b", " c"), + Row("a", "b ", "c ")) + + combinations.zip(expectedRows) + .foreach { case ((ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace), expected) => + val df = spark.read + .option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) + .option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) + .csv(Seq(input).toDS()) + + checkAnswer(df, expected) + } + } + + test("SPARK-18579: ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options - write") { + val df = Seq((" a", "b ", " c ")).toDF() + + // For writing, default of both `ignoreLeadingWhiteSpace` and `ignoreTrailingWhiteSpace` + // are `true`. So, these are excluded. + val combinations = Seq( + (false, false), + (false, true), + (true, false)) + + // Check if written lines ignore each whitespaces as configured. + val expectedLines = Seq( + " a,b , c ", + " a,b, c", + "a,b ,c ") + + combinations.zip(expectedLines) + .foreach { case ((ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace), expected) => + withTempPath { path => + df.write + .option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) + .option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) + .csv(path.getAbsolutePath) + + // Read back the written lines. + val readBack = spark.read.text(path.getAbsolutePath) + checkAnswer(readBack, Row(expected)) + } + } + } } From aefe79890541bc0829f184e03eb3961739ca8ef2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 23 Mar 2017 08:41:30 +0000 Subject: [PATCH 108/512] [MINOR][BUILD] Fix javadoc8 break ## What changes were proposed in this pull request? Several javadoc8 breaks have been introduced. This PR proposes fix those instances so that we can build Scala/Java API docs. ``` [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:6: error: reference not found [error] * flatMapGroupsWithState operations on {link KeyValueGroupedDataset}. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:10: error: reference not found [error] * Both, mapGroupsWithState and flatMapGroupsWithState in {link KeyValueGroupedDataset} [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:51: error: reference not found [error] * {link GroupStateTimeout.ProcessingTimeTimeout}) or event time (i.e. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:52: error: reference not found [error] * {link GroupStateTimeout.EventTimeTimeout}). 
[error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/streaming/GroupState.java:158: error: reference not found [error] * Spark SQL types (see {link Encoder} for more details). [error] ^ [error] .../spark/mllib/target/java/org/apache/spark/ml/fpm/FPGrowthParams.java:26: error: bad use of '>' [error] * Number of partitions (>=1) used by parallel FP-growth. By default the param is not set, and [error] ^ [error] .../spark/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java:30: error: reference not found [error] * {link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState( [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/KeyValueGroupedDataset.java:211: error: reference not found [error] * See {link GroupState} for more details. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/KeyValueGroupedDataset.java:232: error: reference not found [error] * See {link GroupState} for more details. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/KeyValueGroupedDataset.java:254: error: reference not found [error] * See {link GroupState} for more details. [error] ^ [error] .../spark/sql/core/target/java/org/apache/spark/sql/KeyValueGroupedDataset.java:277: error: reference not found [error] * See {link GroupState} for more details. [error] ^ [error] .../spark/core/target/java/org/apache/spark/TaskContextImpl.java:10: error: reference not found [error] * {link TaskMetrics} & {link MetricsSystem} objects are not thread safe. [error] ^ [error] .../spark/core/target/java/org/apache/spark/TaskContextImpl.java:10: error: reference not found [error] * {link TaskMetrics} & {link MetricsSystem} objects are not thread safe. [error] ^ [info] 13 errors ``` ``` jekyll 3.3.1 | Error: Unidoc generation failed ``` ## How was this patch tested? Manually via `jekyll build` Author: hyukjinkwon Closes #17389 from HyukjinKwon/minor-javadoc8-fix. --- .../org/apache/spark/TaskContextImpl.scala | 2 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 4 ++-- .../FlatMapGroupsWithStateFunction.java | 2 +- .../spark/sql/KeyValueGroupedDataset.scala | 8 +++---- .../spark/sql/streaming/GroupState.scala | 22 +++++++++---------- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index ea8dcdfd5d7d..f346cf8d6580 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -38,7 +38,7 @@ import org.apache.spark.util._ * callbacks are protected by locking on the context instance. For instance, this ensures * that you cannot add a completion listener in one thread while we are completing (and calling * the completion listeners) in another thread. Other state is immutable, however the exposed - * [[TaskMetrics]] & [[MetricsSystem]] objects are not thread safe. + * `TaskMetrics` & `MetricsSystem` objects are not thread safe. 
*/ private[spark] class TaskContextImpl( val stageId: Int, diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index e2bc270b38da..65cc80619569 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -69,8 +69,8 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { def getMinSupport: Double = $(minSupport) /** - * Number of partitions (>=1) used by parallel FP-growth. By default the param is not set, and - * partition number of the input dataset is used. + * Number of partitions (at least 1) used by parallel FP-growth. By default the param is not + * set, and partition number of the input dataset is used. * @group expertParam */ @Since("2.2.0") diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index 026b37cabbf1..802949c0ddb6 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -27,7 +27,7 @@ /** * ::Experimental:: * Base interface for a map function used in - * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState( + * {@code org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState( * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode, * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} * @since 2.1.1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 87c562176887..022c2f5629e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -298,7 +298,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[GroupState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -328,7 +328,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[GroupState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -360,7 +360,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[GroupState]] for more details. 
+ * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -400,7 +400,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[GroupState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 60a4d0d8f98a..15df906ca7b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.streaming import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{Encoder, KeyValueGroupedDataset} +import org.apache.spark.sql.KeyValueGroupedDataset import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState /** * :: Experimental :: * * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on [[KeyValueGroupedDataset]]. + * `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`. * * Detail description on `[map/flatMap]GroupsWithState` operation * -------------------------------------------------------------- - * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] + * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in `KeyValueGroupedDataset` * will invoke the user-given function on each group (defined by the grouping function in * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. * For a static batch Dataset, the function will be invoked once per group. For a streaming @@ -70,8 +70,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per * group by calling `setTimeout...()` in `GroupState`. * - Timeouts can be either based on processing time (i.e. - * [[GroupStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. - * [[GroupStateTimeout.EventTimeTimeout]]). + * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e. + * `GroupStateTimeout.EventTimeTimeout`). * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set * duration. Guarantees provided by this timeout with a duration of D ms are as follows: @@ -177,7 +177,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * }}} * * @tparam S User-defined type of the state to be stored for each group. Must be encodable into - * Spark SQL types (see [[Encoder]] for more details). + * Spark SQL types (see `Encoder` for more details). * @since 2.2.0 */ @Experimental @@ -224,7 +224,7 @@ trait GroupState[S] extends LogicalGroupState[S] { /** * Set the timeout duration for this key as a string. 
For example, "1 hour", "2 days", etc. * - * @note, ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") @throws[IllegalStateException]("when state is either not initialized, or already removed") @@ -240,7 +240,7 @@ trait GroupState[S] extends LogicalGroupState[S] { * Set the timeout timestamp for this key as milliseconds in epoch time. * This timestamp cannot be older than the current watermark. * - * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ def setTimeoutTimestamp(timestampMs: Long): Unit @@ -254,7 +254,7 @@ trait GroupState[S] extends LogicalGroupState[S] { * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit @@ -265,7 +265,7 @@ trait GroupState[S] extends LogicalGroupState[S] { * Set the timeout timestamp for this key as a java.sql.Date. * This timestamp cannot be older than the current watermark. * - * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ def setTimeoutTimestamp(timestamp: java.sql.Date): Unit @@ -279,7 +279,7 @@ trait GroupState[S] extends LogicalGroupState[S] { * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit } From b70c03a42002e924e979acbc98a8b464830be532 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 23 Mar 2017 08:42:42 +0000 Subject: [PATCH 109/512] [INFRA] Close stale PRs Closes #16819 Closes #13467 Closes #16083 Closes #17135 Closes #8785 Closes #16278 Closes #16997 Closes #17073 Closes #17220 Added: Closes #12059 Closes #12524 Closes #12888 Closes #16061 Author: Sean Owen Closes #17386 from srowen/StalePRs. From b0ae6a38a3ef65e4e853781c5127ba38997a8546 Mon Sep 17 00:00:00 2001 From: Ye Yin Date: Thu, 23 Mar 2017 13:30:50 +0100 Subject: [PATCH 110/512] Typo fixup in comment ## What changes were proposed in this pull request? Fixup typo in comment. ## How was this patch tested? Don't need. Author: Ye Yin Closes #17396 from hustcat/fix. 
--- .../spark/scheduler/cluster/mesos/MesosClusterManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index ed29b346ba26..911a0857917e 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} /** - * Cluster Manager for creation of Yarn scheduler and backend + * Cluster Manager for creation of Mesos scheduler and backend */ private[spark] class MesosClusterManager extends ExternalClusterManager { private val MESOS_REGEX = """mesos://(.*)""".r From 746a558de2136f91f8fe77c6e51256017aa50913 Mon Sep 17 00:00:00 2001 From: Tyson Condie Date: Thu, 23 Mar 2017 14:32:05 -0700 Subject: [PATCH 111/512] [SPARK-19876][SS][WIP] OneTime Trigger Executor ## What changes were proposed in this pull request? An additional trigger and trigger executor that will execute a single trigger only. One can use this OneTime trigger to have more control over the scheduling of triggers. In addition, this patch requires an optimization to StreamExecution that logs a commit record at the end of successfully processing a batch. This new commit log will be used to determine the next batch (offsets) to process after a restart, instead of using the offset log itself to determine what batch to process next after restart; using the offset log to determine this would process the previously logged batch, always, thus not permitting a OneTime trigger feature. ## How was this patch tested? A number of existing tests have been revised. These tests all assumed that when restarting a stream, the last batch in the offset log is to be re-processed. Given that we now have a commit log that will tell us if that last batch was processed successfully, the results/assumptions of those tests needed to be revised accordingly. In addition, a OneTime trigger test was added to StreamingQuerySuite, which tests: - The semantics of OneTime trigger (i.e., on start, execute a single batch, then stop). - The case when the commit log was not able to successfully log the completion of a batch before restart, which would mean that we should fall back to what's in the offset log. - A OneTime trigger execution that results in an exception being thrown. marmbrus tdas zsxwing Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tyson Condie Author: Tathagata Das Closes #17219 from tcondie/stream-commit. 
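Before the diff itself, a minimal PySpark sketch of how the one-time trigger described above is meant to be used (an illustration, not part of the patch); the input, checkpoint, and output paths are hypothetical:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("trigger-once-demo").getOrCreate()

# Any streaming source works; a text directory keeps the sketch small.
sdf = spark.readStream.format("text").load("/tmp/once-demo/input")

# trigger(once=True) processes the data available at start as a single batch,
# records that batch in the commit log, and then stops the query.
query = (sdf.writeStream
         .trigger(once=True)
         .format("parquet")
         .option("checkpointLocation", "/tmp/once-demo/checkpoint")
         .start("/tmp/once-demo/output"))

query.awaitTermination()  # returns once the single batch has been committed
```

On a later run against the same checkpoint location, the commit log tells the engine that the previous batch completed, so only newly arrived data is processed, which is the restart behaviour motivating the new `BatchCommitLog`.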
--- .../spark/sql/kafka010/KafkaSourceSuite.scala | 2 - project/MimaExcludes.scala | 6 +- python/pyspark/sql/streaming.py | 63 +++-------- python/pyspark/sql/tests.py | 17 ++- .../execution/streaming/BatchCommitLog.scala | 77 +++++++++++++ .../execution/streaming/StreamExecution.scala | 81 +++++++++++--- .../execution/streaming/TriggerExecutor.scala | 11 ++ .../sql/execution/streaming/Triggers.scala | 29 +++++ .../sql/streaming/DataStreamWriter.scala | 2 +- .../{Trigger.scala => ProcessingTime.scala} | 36 +++--- .../apache/spark/sql/streaming/Trigger.java | 105 ++++++++++++++++++ .../streaming/EventTimeWatermarkSuite.scala | 4 +- .../FlatMapGroupsWithStateSuite.scala | 3 +- .../spark/sql/streaming/StreamSuite.scala | 20 +++- .../spark/sql/streaming/StreamTest.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 4 + .../StreamingQueryListenerSuite.scala | 18 ++- .../sql/streaming/StreamingQuerySuite.scala | 48 +++++++- .../test/DataStreamReaderWriterSuite.scala | 5 +- 19 files changed, 439 insertions(+), 94 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala rename sql/core/src/main/scala/org/apache/spark/sql/streaming/{Trigger.scala => ProcessingTime.scala} (74%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 7b6396e0291c..6391d6269c5a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -301,8 +301,6 @@ class KafkaSourceSuite extends KafkaSourceTest { StopStream, StartStream(ProcessingTime(100), clock), waitUntilBatchProcessed, - AdvanceManualClock(100), - waitUntilBatchProcessed, // smallest now empty, 1 more from middle, 9 more from biggest CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bd4528bd2126..9925a8ba7266 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -64,7 +64,11 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$11"), // [SPARK-17161] Removing Python-friendly constructors not needed - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), + + // [SPARK-19876] Add one time trigger, and improve Trigger APIs + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime") ) // Exclude rules for 2.1.x diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 80f4340cdf13..27d6725615a4 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -277,44 +277,6 @@ def resetTerminated(self): self._jsqm.resetTerminated() -class Trigger(object): - """Used to indicate how often results should be produced by a :class:`StreamingQuery`. - - .. 
note:: Experimental - - .. versionadded:: 2.0 - """ - - __metaclass__ = ABCMeta - - @abstractmethod - def _to_java_trigger(self, sqlContext): - """Internal method to construct the trigger on the jvm. - """ - pass - - -class ProcessingTime(Trigger): - """A trigger that runs a query periodically based on the processing time. If `interval` is 0, - the query will run as fast as possible. - - The interval should be given as a string, e.g. '2 seconds', '5 minutes', ... - - .. note:: Experimental - - .. versionadded:: 2.0 - """ - - def __init__(self, interval): - if type(interval) != str or len(interval.strip()) == 0: - raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.") - self.interval = interval - - def _to_java_trigger(self, sqlContext): - return sqlContext._sc._jvm.org.apache.spark.sql.streaming.ProcessingTime.create( - self.interval) - - class DataStreamReader(OptionUtils): """ Interface used to load a streaming :class:`DataFrame` from external storage systems @@ -790,7 +752,7 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None): + def trigger(self, processingTime=None, once=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -800,17 +762,26 @@ def trigger(self, processingTime=None): >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') + >>> # trigger the query for just once batch of data + >>> writer = sdf.writeStream.trigger(once=True) """ - from pyspark.sql.streaming import ProcessingTime - trigger = None + jTrigger = None if processingTime is not None: + if once is not None: + raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: - raise ValueError('The processing time must be a non empty string. Got: %s' % + raise ValueError('Value for processingTime must be a non empty string. Got: %s' % processingTime) - trigger = ProcessingTime(processingTime) - if trigger is None: - raise ValueError('A trigger was not provided. Supported triggers: processingTime.') - self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark)) + interval = processingTime.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( + interval) + elif once is not None: + if once is not True: + raise ValueError('Value for once must be True. 
Got: %s' % once) + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: + raise ValueError('No trigger provided') + self._jwrite = self._jwrite.trigger(jTrigger) return self @ignore_unicode_prefix diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 29d613bc5fe3..b93b7ed19210 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1255,13 +1255,26 @@ def test_save_and_load_builder(self): shutil.rmtree(tmpPath) - def test_stream_trigger_takes_keyword_args(self): + def test_stream_trigger(self): df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + + # Should take at least one arg + try: + df.writeStream.trigger() + except ValueError: + pass + + # Should not take multiple args + try: + df.writeStream.trigger(once=True, processingTime='5 seconds') + except ValueError: + pass + + # Should take only keyword args try: df.writeStream.trigger('5 seconds') self.fail("Should have thrown an exception") except TypeError: - # should throw error pass def test_stream_read_options(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala new file mode 100644 index 000000000000..fb1a4fb9b12f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{InputStream, OutputStream} +import java.nio.charset.StandardCharsets._ + +import scala.io.{Source => IOSource} + +import org.apache.spark.sql.SparkSession + +/** + * Used to write log files that represent batch commit points in structured streaming. + * A commit log file will be written immediately after the successful completion of a + * batch, and before processing the next batch. Here is an execution summary: + * - trigger batch 1 + * - obtain batch 1 offsets and write to offset log + * - process batch 1 + * - write batch 1 to completion log + * - trigger batch 2 + * - obtain bactch 2 offsets and write to offset log + * - process batch 2 + * - write batch 2 to completion log + * .... 
+ * + * The current format of the batch completion log is: + * line 1: version + * line 2: metadata (optional json string) + */ +class BatchCommitLog(sparkSession: SparkSession, path: String) + extends HDFSMetadataLog[String](sparkSession, path) { + + override protected def deserialize(in: InputStream): String = { + // called inside a try-finally where the underlying stream is closed in the caller + val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() + if (!lines.hasNext) { + throw new IllegalStateException("Incomplete log file in the offset commit log") + } + parseVersion(lines.next().trim, BatchCommitLog.VERSION) + // read metadata + lines.next().trim match { + case BatchCommitLog.SERIALIZED_VOID => null + case metadata => metadata + } + } + + override protected def serialize(metadata: String, out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + out.write(s"v${BatchCommitLog.VERSION}".getBytes(UTF_8)) + out.write('\n') + + // write metadata or void + out.write((if (metadata == null) BatchCommitLog.SERIALIZED_VOID else metadata) + .getBytes(UTF_8)) + } +} + +object BatchCommitLog { + private val VERSION = 1 + private val SERIALIZED_VOID = "{}" +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 60d5283e6b21..34e9262af7cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -165,6 +165,8 @@ class StreamExecution( private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) + case OneTimeTrigger => OneTimeExecutor() + case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } /** Defines the internal state of execution */ @@ -209,6 +211,13 @@ class StreamExecution( */ val offsetLog = new OffsetSeqLog(sparkSession, checkpointFile("offsets")) + /** + * A log that records the batch ids that have completed. This is used to check if a batch was + * fully processed, and its output was committed to the sink, hence no need to process it again. + * This is used (for instance) during restart, to help identify which batch to run next. + */ + val batchCommitLog = new BatchCommitLog(sparkSession, checkpointFile("commits")) + /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING @@ -291,10 +300,13 @@ class StreamExecution( runBatch(sparkSessionToRunBatches) } } - // Report trigger as finished and construct progress object. finishTrigger(dataAvailable) if (dataAvailable) { + // Update committed offsets. + committedOffsets ++= availableOffsets + batchCommitLog.add(currentBatchId, null) + logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 } else { @@ -306,9 +318,6 @@ class StreamExecution( } else { false } - - // Update committed offsets. 
- committedOffsets ++= availableOffsets updateStatusMessage("Waiting for next trigger") continueToRun }) @@ -392,13 +401,33 @@ class StreamExecution( * - currentBatchId * - committedOffsets * - availableOffsets + * The basic structure of this method is as follows: + * + * Identify (from the offset log) the offsets used to run the last batch + * IF last batch exists THEN + * Set the next batch to be executed as the last recovered batch + * Check the commit log to see which batch was committed last + * IF the last batch was committed THEN + * Call getBatch using the last batch start and end offsets + * // ^^^^ above line is needed since some sources assume last batch always re-executes + * Setup for a new batch i.e., start = last batch end, and identify new end + * DONE + * ELSE + * Identify a brand new batch + * DONE */ private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { offsetLog.getLatest() match { - case Some((batchId, nextOffsets)) => - logInfo(s"Resuming streaming query, starting with batch $batchId") - currentBatchId = batchId + case Some((latestBatchId, nextOffsets)) => + /* First assume that we are re-executing the latest known batch + * in the offset log */ + currentBatchId = latestBatchId availableOffsets = nextOffsets.toStreamProgress(sources) + /* Initialize committed offsets to a committed batch, which at this + * is the second latest batch id in the offset log. */ + offsetLog.get(latestBatchId - 1).foreach { secondLatestBatchId => + committedOffsets = secondLatestBatchId.toStreamProgress(sources) + } // update offset metadata nextOffsets.metadata.foreach { metadata => @@ -419,14 +448,37 @@ class StreamExecution( SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString) } - logDebug(s"Found possibly unprocessed offsets $availableOffsets " + - s"at batch timestamp ${offsetSeqMetadata.batchTimestampMs}") - - offsetLog.get(batchId - 1).foreach { - case lastOffsets => - committedOffsets = lastOffsets.toStreamProgress(sources) - logDebug(s"Resuming with committed offsets: $committedOffsets") + /* identify the current batch id: if commit log indicates we successfully processed the + * latest batch id in the offset log, then we can safely move to the next batch + * i.e., committedBatchId + 1 */ + batchCommitLog.getLatest() match { + case Some((latestCommittedBatchId, _)) => + if (latestBatchId == latestCommittedBatchId) { + /* The last batch was successfully committed, so we can safely process a + * new next batch but first: + * Make a call to getBatch using the offsets from previous batch. + * because certain sources (e.g., KafkaSource) assume on restart the last + * batch will be executed before getOffset is called again. 
*/ + availableOffsets.foreach { ao: (Source, Offset) => + val (source, end) = ao + if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { + val start = committedOffsets.get(source) + source.getBatch(start, end) + } + } + currentBatchId = latestCommittedBatchId + 1 + committedOffsets ++= availableOffsets + // Construct a new batch be recomputing availableOffsets + constructNextBatch() + } else if (latestCommittedBatchId < latestBatchId - 1) { + logWarning(s"Batch completion log latest batch id is " + + s"${latestCommittedBatchId}, which is not trailing " + + s"batchid $latestBatchId by one") + } + case None => logInfo("no commit log present") } + logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 @@ -523,6 +575,7 @@ class StreamExecution( // Note that purge is exclusive, i.e. it purges everything before the target ID. if (minBatchesToRetain < currentBatchId) { offsetLog.purge(currentBatchId - minBatchesToRetain) + batchCommitLog.purge(currentBatchId - minBatchesToRetain) } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index ac510df209f0..02996ac854f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -29,6 +29,17 @@ trait TriggerExecutor { def execute(batchRunner: () => Boolean): Unit } +/** + * A trigger executor that runs a single batch only, then terminates. + */ +case class OneTimeExecutor() extends TriggerExecutor { + + /** + * Execute a single batch using `batchRunner`. + */ + override def execute(batchRunner: () => Boolean): Unit = batchRunner() +} + /** * A trigger executor that runs a batch every `intervalMs` milliseconds. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala new file mode 100644 index 000000000000..271bc4da99c0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.streaming.Trigger + +/** + * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * the query. 
+ */ +@Experimental +@InterfaceStability.Evolving +case object OneTimeTrigger extends Trigger diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index fe52013badb6..f2f700590ca8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -377,7 +377,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var outputMode: OutputMode = OutputMode.Append - private var trigger: Trigger = ProcessingTime(0L) + private var trigger: Trigger = Trigger.ProcessingTime(0L) private var extraOptions = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index 68f2eab9d45f..bdad8e4717be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -26,16 +26,6 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.unsafe.types.CalendarInterval -/** - * :: Experimental :: - * Used to indicate how often results should be produced by a [[StreamingQuery]]. - * - * @since 2.0.0 - */ -@Experimental -@InterfaceStability.Evolving -sealed trait Trigger - /** * :: Experimental :: * A trigger that runs a query periodically based on the processing time. 
If `interval` is 0, @@ -43,24 +33,25 @@ sealed trait Trigger * * Scala Example: * {{{ - * df.write.trigger(ProcessingTime("10 seconds")) + * df.writeStream.trigger(ProcessingTime("10 seconds")) * * import scala.concurrent.duration._ - * df.write.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(ProcessingTime(10.seconds)) * }}} * * Java Example: * {{{ - * df.write.trigger(ProcessingTime.create("10 seconds")) + * df.writeStream.trigger(ProcessingTime.create("10 seconds")) * * import java.util.concurrent.TimeUnit - * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} * * @since 2.0.0 */ @Experimental @InterfaceStability.Evolving +@deprecated("use Trigger.ProcessingTimeTrigger(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") } @@ -73,6 +64,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { */ @Experimental @InterfaceStability.Evolving +@deprecated("use Trigger.ProcessingTimeTrigger(intervalMs)", "2.2.0") object ProcessingTime { /** @@ -80,11 +72,13 @@ object ProcessingTime { * * Example: * {{{ - * df.write.trigger(ProcessingTime("10 seconds")) + * df.writeStream.trigger(ProcessingTime("10 seconds")) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTimeTrigger(interval) */ + @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") def apply(interval: String): ProcessingTime = { if (StringUtils.isBlank(interval)) { throw new IllegalArgumentException( @@ -110,11 +104,13 @@ object ProcessingTime { * Example: * {{{ * import scala.concurrent.duration._ - * df.write.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(ProcessingTime(10.seconds)) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTimeTrigger(interval) */ + @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") def apply(interval: Duration): ProcessingTime = { new ProcessingTime(interval.toMillis) } @@ -124,11 +120,13 @@ object ProcessingTime { * * Example: * {{{ - * df.write.trigger(ProcessingTime.create("10 seconds")) + * df.writeStream.trigger(ProcessingTime.create("10 seconds")) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTimeTrigger(interval) */ + @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") def create(interval: String): ProcessingTime = { apply(interval) } @@ -139,11 +137,13 @@ object ProcessingTime { * Example: * {{{ * import java.util.concurrent.TimeUnit - * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTimeTrigger(interval) */ + @deprecated("use Trigger.ProcessingTimeTrigger(interval, unit)", "2.2.0") def create(interval: Long, unit: TimeUnit): ProcessingTime = { new ProcessingTime(unit.toMillis(interval)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java new file mode 100644 index 000000000000..a03a851f245f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import java.util.concurrent.TimeUnit; + +import scala.concurrent.duration.Duration; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; + +/** + * :: Experimental :: + * Policy used to indicate how often results should be produced by a [[StreamingQuery]]. + * + * @since 2.0.0 + */ +@Experimental +@InterfaceStability.Evolving +public class Trigger { + + /** + * :: Experimental :: + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is 0, the query will run as fast as possible. + * + * @since 2.2.0 + */ + public static Trigger ProcessingTime(long intervalMs) { + return ProcessingTime.apply(intervalMs); + } + + /** + * :: Experimental :: + * (Java-friendly) + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is 0, the query will run as fast as possible. + * + * {{{ + * import java.util.concurrent.TimeUnit + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.2.0 + */ + public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { + return ProcessingTime.create(interval, timeUnit); + } + + /** + * :: Experimental :: + * (Scala-friendly) + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `duration` is 0, the query will run as fast as possible. + * + * {{{ + * import scala.concurrent.duration._ + * df.writeStream.trigger(ProcessingTime(10.seconds)) + * }}} + * @since 2.2.0 + */ + public static Trigger ProcessingTime(Duration interval) { + return ProcessingTime.apply(interval); + } + + /** + * :: Experimental :: + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is effectively 0, the query will run as fast as possible. + * + * {{{ + * df.writeStream.trigger(Trigger.ProcessingTime("10 seconds")) + * }}} + * @since 2.2.0 + */ + public static Trigger ProcessingTime(String interval) { + return ProcessingTime.apply(interval); + } + + /** + * A trigger that process only one batch of data in a streaming query then terminates + * the query. 
+ * + * @since 2.2.0 + */ + public static Trigger Once() { + return OneTimeTrigger$.MODULE$; + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 7614ea5eb3c0..fd850a7365e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -218,7 +218,9 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin AddData(inputData, 25), // Evict items less than previous watermark. CheckLastBatch((10, 5)), StopStream, - AssertOnQuery { q => // clear the sink + AssertOnQuery { q => // purge commit and clear the sink + val commit = q.batchCommitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + q.batchCommitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 89a25973afdd..a00a1a582a97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -575,9 +575,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf StopStream, StartStream(ProcessingTime("1 second"), triggerClock = clock), + AdvanceManualClock(10 * 1000), AddData(inputData, "c"), - AdvanceManualClock(20 * 1000), + AdvanceManualClock(1 * 1000), CheckLastBatch(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index f01211e20cbf..32920f6dfa22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -156,6 +156,15 @@ class StreamSuite extends StreamTest { AssertOnQuery(_.offsetLog.getLatest().get._1 == expectedId, s"offsetLog's latest should be $expectedId") + // Check the latest batchid in the commit log + def CheckCommitLogLatestBatchId(expectedId: Int): AssertOnQuery = + AssertOnQuery(_.batchCommitLog.getLatest().get._1 == expectedId, + s"commitLog's latest should be $expectedId") + + // Ensure that there has not been an incremental execution after restart + def CheckNoIncrementalExecutionCurrentBatchId(): AssertOnQuery = + AssertOnQuery(_.lastExecution == null, s"lastExecution not expected to run") + // For each batch, we would log the state change during the execution // This checks whether the key of the state change log is the expected batch id def CheckIncrementalExecutionCurrentBatchId(expectedId: Int): AssertOnQuery = @@ -181,6 +190,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 0 CheckAnswer(1, 2, 3), CheckIncrementalExecutionCurrentBatchId(0), + CheckCommitLogLatestBatchId(0), CheckOffsetLogLatestBatchId(0), CheckSinkLatestBatchId(0), // Add some data in batch 1 @@ -191,6 +201,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 1 CheckAnswer(1, 2, 3, 4, 5, 6), CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), @@ -203,6 +214,7 @@ class 
StreamSuite extends StreamTest { // the currentId does not get logged (e.g. as 2) even if the clock has advanced many times CheckAnswer(1, 2, 3, 4, 5, 6), CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), @@ -210,14 +222,15 @@ class StreamSuite extends StreamTest { StopStream, StartStream(ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), - /* -- batch 1 rerun ----------------- */ - // this batch 1 would re-run because the latest batch id logged in offset log is 1 + /* -- batch 1 no rerun ----------------- */ + // batch 1 would not re-run because the latest batch id logged in commit log is 1 AdvanceManualClock(10 * 1000), + CheckNoIncrementalExecutionCurrentBatchId(), /* -- batch 2 ----------------------- */ // Check the results of batch 1 CheckAnswer(1, 2, 3, 4, 5, 6), - CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), // Add some data in batch 2 @@ -228,6 +241,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 2 CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8, 9), CheckIncrementalExecutionCurrentBatchId(2), + CheckCommitLogLatestBatchId(2), CheckOffsetLogLatestBatchId(2), CheckSinkLatestBatchId(2)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 60e2375a9817..8cf179133681 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -159,7 +159,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Starts the stream, resuming if data has already been processed. It must not be running. 
*/ case class StartStream( - trigger: Trigger = ProcessingTime(0), + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty) extends StreamAction diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 0c8015672bab..600c039cd0b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -272,11 +272,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte StopStream, AssertOnQuery { q => // clear the sink q.sink.asInstanceOf[MemorySink].clear() + q.batchCommitLog.purge(3) // advance by a minute i.e., 90 seconds total clock.advance(60 * 1000L) true }, StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + // The commit log blown, causing the last batch to re-run CheckLastBatch((20L, 1), (85L, 1)), AssertOnQuery { q => clock.getTimeMillis() == 90000L @@ -322,11 +324,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte StopStream, AssertOnQuery { q => // clear the sink q.sink.asInstanceOf[MemorySink].clear() + q.batchCommitLog.purge(3) // advance by 60 days i.e., 90 days total clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60) true }, StartStream(ProcessingTime("10 day"), triggerClock = clock), + // Commit log blown, causing a re-run of the last batch CheckLastBatch((20L, 1), (85L, 1)), // advance clock to 100 days, should retain keys >= 90 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index eb09b9ffcfc5..03dad8a6ddbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -57,6 +57,20 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val inputData = new MemoryStream[Int](0, sqlContext) val df = inputData.toDS().as[Long].map { 10 / _ } val listener = new EventCollector + + case class AssertStreamExecThreadToWaitForClock() + extends AssertOnQuery(q => { + eventually(Timeout(streamingTimeout)) { + if (q.exception.isEmpty) { + assert(clock.asInstanceOf[StreamManualClock].isStreamWaitingAt(clock.getTimeMillis)) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + }, "") + try { // No events until started spark.streams.addListener(listener) @@ -81,6 +95,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // Progress event generated when data processed AddData(inputData, 1, 2), AdvanceManualClock(100), + AssertStreamExecThreadToWaitForClock(), CheckAnswer(10, 5), AssertOnQuery { query => assert(listener.progressEvents.nonEmpty) @@ -109,8 +124,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // Termination event generated with exception message when stopped with error StartStream(ProcessingTime(100), triggerClock = clock), + AssertStreamExecThreadToWaitForClock(), AddData(inputData, 0), - AdvanceManualClock(100), + AdvanceManualClock(100), // process bad data ExpectFailure[SparkException](), AssertOnQuery { query => eventually(Timeout(streamingTimeout)) { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index a0a2b2b4c9b3..3f41ecdb7ff6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -158,6 +158,49 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi ) } + testQuietly("OneTime trigger, commit log, and exception") { + import Trigger.Once + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map { 6 / _} + + testStream(mapped)( + AssertOnQuery(_.isActive === true), + StopStream, + AddData(inputData, 1, 2), + StartStream(trigger = Once), + CheckAnswer(6, 3), + StopStream, // clears out StreamTest state + AssertOnQuery { q => + // both commit log and offset log contain the same (latest) batch id + q.batchCommitLog.getLatest().map(_._1).getOrElse(-1L) == + q.offsetLog.getLatest().map(_._1).getOrElse(-2L) + }, + AssertOnQuery { q => + // blow away commit log and sink result + q.batchCommitLog.purge(1) + q.sink.asInstanceOf[MemorySink].clear() + true + }, + StartStream(trigger = Once), + CheckAnswer(6, 3), // ensure we fall back to offset log and reprocess batch + StopStream, + AddData(inputData, 3), + StartStream(trigger = Once), + CheckLastBatch(2), // commit log should be back in place + StopStream, + AddData(inputData, 0), + StartStream(trigger = Once), + ExpectFailure[SparkException](), + AssertOnQuery(_.isActive === false), + AssertOnQuery(q => { + q.exception.get.startOffset === + q.committedOffsets.toOffsetSeq(Seq(inputData), OffsetSeqMetadata()).toString && + q.exception.get.endOffset === + q.availableOffsets.toOffsetSeq(Seq(inputData), OffsetSeqMetadata()).toString + }, "incorrect start offset or end offset on exception") + ) + } + testQuietly("status, lastProgress, and recentProgress") { import StreamingQuerySuite._ clock = new StreamManualClock @@ -237,6 +280,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AdvanceManualClock(500), // time = 1100 to unblock job AssertOnQuery { _ => clock.getTimeMillis() === 1100 }, CheckAnswer(2), + AssertStreamExecThreadToWaitForClock(), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -275,6 +319,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AddData(inputData, 1, 2), AdvanceManualClock(100), // allow another trigger + AssertStreamExecThreadToWaitForClock(), CheckAnswer(4), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), @@ -306,8 +351,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi // Test status and progress after query terminated with error StartStream(ProcessingTime(100), triggerClock = clock), + AdvanceManualClock(100), // ensure initial trigger completes before AddData AddData(inputData, 0), - AdvanceManualClock(100), + AdvanceManualClock(100), // allow another trigger ExpectFailure[SparkException](), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 
341ab0eb923d..05cd3d9f7c2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.{ProcessingTime => DeprecatedProcessingTime, _} +import org.apache.spark.sql.streaming.Trigger._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -346,7 +347,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { q = df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) + .trigger(ProcessingTime(100, TimeUnit.SECONDS)) .start() q.stop() From b7be05a203b3e2a307147ea0c6cb0dec03da82a2 Mon Sep 17 00:00:00 2001 From: erenavsarogullari Date: Thu, 23 Mar 2017 17:20:52 -0700 Subject: [PATCH 112/512] [SPARK-19567][CORE][SCHEDULER] Support some Schedulable variables immutability and access ## What changes were proposed in this pull request? Some `Schedulable` entities' (`Pool` and `TaskSetManager`) variables need refactoring for _immutability_ and _access modifier_ levels, as follows: - From `var` to `val` (where mutability is not required): this is important to support immutability as much as possible. - For example, in `Pool`: `weight`, `minShare`, `priority`, `name` and `taskSetSchedulingAlgorithm`. - Access modifiers: in particular, access to `var`s needs to be restricted from other parts of the codebase to prevent potential side effects. - For example, in `TaskSetManager`: `tasksSuccessful`, `totalResultSize`, `calculatedTasks`, etc. This PR is related to #15604 and has been created separately to keep the patch content isolated and to help the reviewers. ## How was this patch tested? Added new UTs and existing UT coverage. Author: erenavsarogullari Closes #16905 from erenavsarogullari/SPARK-19567. --- .../org/apache/spark/scheduler/Pool.scala | 12 +++---- .../spark/scheduler/TaskSchedulerImpl.scala | 19 ++++++---- .../spark/scheduler/TaskSetManager.scala | 36 ++++++++++--------- .../spark/scheduler/DAGSchedulerSuite.scala | 8 ++--- .../ExternalClusterManagerSuite.scala | 4 +-- .../apache/spark/scheduler/PoolSuite.scala | 6 ++++ .../scheduler/TaskSchedulerImplSuite.scala | 12 +++++-- 7 files changed, 58 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 2a69a6c5e879..1181371ab425 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -37,24 +37,24 @@ private[spark] class Pool( val schedulableQueue = new ConcurrentLinkedQueue[Schedulable] val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable] - var weight = initWeight - var minShare = initMinShare + val weight = initWeight + val minShare = initMinShare var runningTasks = 0 - var priority = 0 + val priority = 0 // A pool's stage id is used to break the tie in scheduling. 
var stageId = -1 - var name = poolName + val name = poolName var parent: Pool = null - var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { + private val taskSetSchedulingAlgorithm: SchedulingAlgorithm = { schedulingMode match { case SchedulingMode.FAIR => new FairSchedulingAlgorithm() case SchedulingMode.FIFO => new FIFOSchedulingAlgorithm() case _ => - val msg = "Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." + val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." throw new IllegalArgumentException(msg) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bfbcfa1aa386..8257c70d672a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -59,6 +59,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( extends TaskScheduler with Logging { + import TaskSchedulerImpl._ + def this(sc: SparkContext) = { this( sc, @@ -130,17 +132,18 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val mapOutputTracker = SparkEnv.get.mapOutputTracker - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null + private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO - private val schedulingModeConf = conf.get("spark.scheduler.mode", "FIFO") + private val schedulingModeConf = conf.get(SCHEDULER_MODE_PROPERTY, SchedulingMode.FIFO.toString) val schedulingMode: SchedulingMode = try { SchedulingMode.withName(schedulingModeConf.toUpperCase) } catch { case e: java.util.NoSuchElementException => - throw new SparkException(s"Unrecognized spark.scheduler.mode: $schedulingModeConf") + throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") } + val rootPool: Pool = new Pool("", schedulingMode, 0, 0) + // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) @@ -150,8 +153,6 @@ private[spark] class TaskSchedulerImpl private[scheduler]( def initialize(backend: SchedulerBackend) { this.backend = backend - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) schedulableBuilder = { schedulingMode match { case SchedulingMode.FIFO => @@ -159,7 +160,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( case SchedulingMode.FAIR => new FairSchedulableBuilder(rootPool, conf) case _ => - throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode") + throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " + + s"$schedulingMode") } } schedulableBuilder.buildPools() @@ -683,6 +685,9 @@ private[spark] class TaskSchedulerImpl private[scheduler]( private[spark] object TaskSchedulerImpl { + + val SCHEDULER_MODE_PROPERTY = "spark.scheduler.mode" + /** * Used to balance containers across hosts. 
* diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 11633bef3cfc..fd93a1f5c5d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -78,16 +78,16 @@ private[spark] class TaskSetManager( private val numFailures = new Array[Int](numTasks) val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksSuccessful = 0 + private[scheduler] var tasksSuccessful = 0 - var weight = 1 - var minShare = 0 + val weight = 1 + val minShare = 0 var priority = taskSet.priority var stageId = taskSet.stageId val name = "TaskSet_" + taskSet.id var parent: Pool = null - var totalResultSize = 0L - var calculatedTasks = 0 + private var totalResultSize = 0L + private var calculatedTasks = 0 private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { blacklistTracker.map { _ => @@ -95,7 +95,7 @@ private[spark] class TaskSetManager( } } - val runningTasksSet = new HashSet[Long] + private[scheduler] val runningTasksSet = new HashSet[Long] override def runningTasks: Int = runningTasksSet.size @@ -105,7 +105,7 @@ private[spark] class TaskSetManager( // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie // state in order to continue to track and account for the running tasks. // TODO: We should kill any running task attempts when the task set manager becomes a zombie. - var isZombie = false + private[scheduler] var isZombie = false // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the @@ -129,17 +129,17 @@ private[spark] class TaskSetManager( private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] // Set containing pending tasks with no locality preferences. - var pendingTasksWithNoPrefs = new ArrayBuffer[Int] + private[scheduler] var pendingTasksWithNoPrefs = new ArrayBuffer[Int] // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] + private val allPendingTasks = new ArrayBuffer[Int] // Tasks that can be speculated. Since these will be a small fraction of total // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] + private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] + private val taskInfos = new HashMap[Long, TaskInfo] // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = @@ -148,7 +148,7 @@ private[spark] class TaskSetManager( // Map of recent exceptions (identified by string representation and top stack frame) to // duplicate count (how many times the same exception has appeared) and time the full exception // was printed. This should ideally be an LRU map that can drop old exceptions automatically. 
- val recentExceptions = HashMap[String, (Int, Long)]() + private val recentExceptions = HashMap[String, (Int, Long)]() // Figure out the current map output tracker epoch and set it on all tasks val epoch = sched.mapOutputTracker.getEpoch @@ -169,20 +169,22 @@ private[spark] class TaskSetManager( * This allows a performance optimization, of skipping levels that aren't relevant (eg., skip * PROCESS_LOCAL if no tasks could be run PROCESS_LOCAL for the current set of executors). */ - var myLocalityLevels = computeValidLocalityLevels() - var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + private[scheduler] var myLocalityLevels = computeValidLocalityLevels() + + // Time to wait at each level + private[scheduler] var localityWaits = myLocalityLevels.map(getLocalityWait) // Delay scheduling variables: we keep track of our current locality level and the time we // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level + private var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + private var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - var emittedTaskSizeWarning = false + private[scheduler] var emittedTaskSizeWarning = false /** Add a task to all the pending-task lists that it should be on. */ private def addPendingTask(index: Int) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index dfad5db68a91..a9389003d5db 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -110,8 +110,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val cancelledStages = new HashSet[Int]() val taskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start() = {} override def stop() = {} override def executorHeartbeatReceived( @@ -542,8 +542,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // make sure that the DAGScheduler doesn't crash when the TaskScheduler // doesn't implement killTask() val noKillTaskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index e87cebf0cf35..37c124a726be 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -73,8 +73,8 @@ private class DummySchedulerBackend extends SchedulerBackend { private class DummyTaskScheduler extends TaskScheduler { var initialized = false - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index cddff3dd3586..4901062a7855 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -286,6 +286,12 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { assert(testPool.getSchedulableByName(taskSetManager.name) === taskSetManager) } + test("Pool should throw IllegalArgumentException when schedulingMode is not supported") { + intercept[IllegalArgumentException] { + new Pool("TestPool", SchedulingMode.NONE, 0, 1) + } + } + private def verifyPool(rootPool: Pool, poolName: String, expectedInitMinShare: Int, expectedInitWeight: Int, expectedSchedulingMode: SchedulingMode): Unit = { val selectedPool = rootPool.getSchedulableByName(poolName) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9ae0bcd9b886..8b9d45f734cd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -75,9 +75,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") - confs.foreach { case (k, v) => - conf.set(k, v) - } + confs.foreach { case (k, v) => conf.set(k, v) } sc = new SparkContext(conf) taskScheduler = new TaskSchedulerImpl(sc) setupHelper() @@ -904,4 +902,12 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(taskDescs.size === 1) assert(taskDescs.head.executorId === "exec2") } + + test("TaskScheduler should throw IllegalArgumentException when schedulingMode is not supported") { + intercept[IllegalArgumentException] { + val taskScheduler = setupScheduler( + TaskSchedulerImpl.SCHEDULER_MODE_PROPERTY -> SchedulingMode.NONE.toString) + taskScheduler.initialize(new FakeSchedulerBackend) + } + } } From c7911807050227fcd13161ce090330d9d8daa533 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Thu, 23 Mar 2017 17:39:33 -0700 Subject: [PATCH 113/512] [SPARK-10849][SQL] Adds option to the JDBC data source write for user to specify database column type for the create table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently JDBC data source creates tables in the target database using the default type mapping, and the JDBC dialect mechanism.  If users want to specify different database data type for only some of columns, there is no option available. 
In scenarios where the default mapping does not work, users are forced to create tables in the target database before writing. This workaround is probably not acceptable from a usability point of view. This PR provides a user-defined type mapping for specific columns. The solution is to allow users to specify the database column data type for the create table as a JDBC data source option (createTableColumnTypes) on write. Data type information can be specified in the same format as table schema DDL (e.g., `name CHAR(64), comments VARCHAR(1024)`). Not all supported target database types can be specified; the data types also have to be valid Spark SQL data types. For example, a user cannot specify the target database CLOB data type. This will be supported in a follow-up PR. Example: ```Scala df.write .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") .jdbc(url, "TEST.DBCOLTYPETEST", properties) ``` ## How was this patch tested? Added new test cases to the JDBCWriteSuite Author: sureshthalamati Closes #16209 from sureshthalamati/jdbc_custom_dbtype_option_json-spark-10849. --- docs/sql-programming-guide.md | 7 + .../sql/JavaSQLDataSourceExample.java | 5 + examples/src/main/python/sql/datasource.py | 6 + .../examples/sql/SQLDataSourceExample.scala | 5 + .../datasources/jdbc/JDBCOptions.scala | 2 + .../jdbc/JdbcRelationProvider.scala | 4 +- .../datasources/jdbc/JdbcUtils.scala | 66 +++++++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 150 +++++++++++++++++- 9 files changed, 235 insertions(+), 12 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b077575155eb..7ae9847983d4 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1223,6 +1223,13 @@ the following case-insensitive options: This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table (e.g., CREATE TABLE t (name string) ENGINE=InnoDB.). This option applies only to writing. + + + createTableColumnTypes + + The database column data types to use instead of the defaults, when creating the table. Data type information should be specified in the same format as CREATE TABLE columns syntax (e.g: "name CHAR(64), comments VARCHAR(1024)"). The specified types should be valid spark sql data types. This option applies only to writing. + +
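As a minimal sketch of how this new option composes with the existing JDBC write path (the SparkSession setup and connection properties below are placeholders, not part of this patch), columns omitted from `createTableColumnTypes` are assumed to fall back to the default dialect mapping, as exercised by the "partial schema" test case further below:

```Scala
import java.util.Properties

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("jdbc-coltypes-sketch").getOrCreate()
import spark.implicits._

// A toy DataFrame; only the "city" column type is overridden below,
// so "id" and "name" keep the default JDBC dialect mapping.
val df = Seq((1, "dave", "Boston"), (2, "anna", "Berlin")).toDF("id", "name", "city")

val props = new Properties()
props.put("user", "username")
props.put("password", "password")

df.write
  .mode("overwrite")
  .option("createTableColumnTypes", "city VARCHAR(32)")
  .jdbc("jdbc:postgresql:dbserver", "schema.tablename", props)
```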
      diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 82bb284ea3e5..1a7054614b34 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -258,6 +258,11 @@ private static void runJdbcDatasetExample(SparkSession spark) { jdbcDF2.write() .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); + + // Specifying create table column data types on write + jdbcDF.write() + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); // $example off:jdbc_dataset$ } } diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index e9aa9d9ac258..e4abb0933345 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -169,6 +169,12 @@ def jdbc_dataset_example(spark): jdbcDF2.write \ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", properties={"user": "username", "password": "password"}) + + # Specifying create table column data types on write + jdbcDF.write \ + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") \ + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", + properties={"user": "username", "password": "password"}) # $example off:jdbc_dataset$ diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 381e69cda841..82fd56de3984 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -181,6 +181,11 @@ object SQLDataSourceExample { jdbcDF2.write .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) + + // Specifying create table column data types on write + jdbcDF.write + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) // $example off:jdbc_dataset$ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index d4d34646545b..89fe86c038b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -119,6 +119,7 @@ class JDBCOptions( // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") + val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES) val batchSize = { val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt require(size >= 1, @@ -154,6 +155,7 @@ object JDBCOptions { val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") + val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_BATCH_INSERT_SIZE = 
newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 88f6cb002130..74dcfb06f5c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -69,7 +69,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } else { // Otherwise, do not truncate the table, instead drop and recreate it dropTable(conn, options.table) - createTable(conn, df.schema, options) + createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } @@ -87,7 +87,7 @@ class JdbcRelationProvider extends CreatableRelationProvider // Therefore, it is okay to do nothing here and then just return the relation below. } } else { - createTable(conn, df.schema, options) + createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index d89f60087417..774d1ba19432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -30,7 +30,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -680,18 +681,70 @@ object JdbcUtils extends Logging { /** * Compute the schema string for this RDD. */ - def schemaString(schema: StructType, url: String): String = { + def schemaString( + df: DataFrame, + url: String, + createTableColumnTypes: Option[String] = None): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - schema.fields foreach { field => + val userSpecifiedColTypesMap = createTableColumnTypes + .map(parseUserSpecifiedCreateTableColumnTypes(df, _)) + .getOrElse(Map.empty[String, String]) + df.schema.fields.foreach { field => val name = dialect.quoteIdentifier(field.name) - val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition + val typ = userSpecifiedColTypesMap + .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition) val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") } if (sb.length < 2) "" else sb.substring(2) } + /** + * Parses the user specified createTableColumnTypes option value string specified in the same + * format as create table ddl column types, and returns Map of field name and the data type to + * use in-place of the default data type. 
+ */ + private def parseUserSpecifiedCreateTableColumnTypes( + df: DataFrame, + createTableColumnTypes: String): Map[String, String] = { + def typeName(f: StructField): String = { + // char/varchar gets translated to string type. Real data type specified by the user + // is available in the field metadata as HIVE_TYPE_STRING + if (f.metadata.contains(HIVE_TYPE_STRING)) { + f.metadata.getString(HIVE_TYPE_STRING) + } else { + f.dataType.catalogString + } + } + + val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) + val nameEquality = df.sparkSession.sessionState.conf.resolver + + // checks duplicate columns in the user specified column types. + userSchema.fieldNames.foreach { col => + val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col)) + if (duplicatesCols.size >= 2) { + throw new AnalysisException( + "Found duplicate column(s) in createTableColumnTypes option value: " + + duplicatesCols.mkString(", ")) + } + } + + // checks if user specified column names exist in the DataFrame schema + userSchema.fieldNames.foreach { col => + df.schema.find(f => nameEquality(f.name, col)).getOrElse { + throw new AnalysisException( + s"createTableColumnTypes option column $col not found in schema " + + df.schema.catalogString) + } + } + + val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap + val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis + if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap) + } + /** * Saves the RDD to the database in a single transaction. */ @@ -726,9 +779,10 @@ object JdbcUtils extends Logging { */ def createTable( conn: Connection, - schema: StructType, + df: DataFrame, options: JDBCOptions): Unit = { - val strSchema = schemaString(schema, options.url) + val strSchema = schemaString( + df, options.url, options.createTableColumnTypes) val table = options.table val createTableOptions = options.createTableOptions // Create the table if the table does not exist. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5463728ca0c1..4a02277631f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -869,7 +869,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") { val df = spark.createDataset(Seq("a", "b", "c")).toDF("order") - val schema = JdbcUtils.schemaString(df.schema, "jdbc:mysql://localhost:3306/temp") + val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index ec7b19e666ec..bf1fd160704f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.jdbc -import java.sql.DriverManager +import java.sql.{Date, DriverManager, Timestamp} import java.util.Properties import scala.collection.JavaConverters.propertiesAsScalaMapConverter import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, Row, SaveMode} -import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -362,4 +363,147 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(sql("select * from people_view").count() == 2) } } + + test("SPARK-10849: test schemaString - from createTableColumnTypes option values") { + def testCreateTableColDataTypes(types: Seq[String]): Unit = { + val colTypes = types.zipWithIndex.map { case (t, i) => (s"col$i", t) } + val schema = colTypes + .foldLeft(new StructType())((schema, colType) => schema.add(colType._1, colType._2)) + val createTableColTypes = + colTypes.map { case (col, dataType) => s"$col $dataType" }.mkString(", ") + val df = spark.createDataFrame(sparkContext.parallelize(Seq(Row.empty)), schema) + + val expectedSchemaStr = + colTypes.map { case (col, dataType) => s""""$col" $dataType """ }.mkString(", ") + + assert(JdbcUtils.schemaString(df, url1, Option(createTableColTypes)) == expectedSchemaStr) + } + + testCreateTableColDataTypes(Seq("boolean")) + testCreateTableColDataTypes(Seq("tinyint", "smallint", "int", "bigint")) + testCreateTableColDataTypes(Seq("float", "double")) + testCreateTableColDataTypes(Seq("string", "char(10)", "varchar(20)")) + testCreateTableColDataTypes(Seq("decimal(10,0)", "decimal(10,5)")) + testCreateTableColDataTypes(Seq("date", "timestamp")) + testCreateTableColDataTypes(Seq("binary")) + } + + test("SPARK-10849: create table using user specified column type and verify on target table") { + def testUserSpecifiedColTypes( + df: DataFrame, + createTableColTypes: String, + expectedTypes: Map[String, String]): Unit = { + df.write + .mode(SaveMode.Overwrite) + .option("createTableColumnTypes", createTableColTypes) + .jdbc(url1, "TEST.DBCOLTYPETEST", properties) + + // verify the data types of the created 
table by reading the database catalog of H2 + val query = + """ + |(SELECT column_name, type_name, character_maximum_length + | FROM information_schema.columns WHERE table_name = 'DBCOLTYPETEST') + """.stripMargin + val rows = spark.read.jdbc(url1, query, properties).collect() + + rows.foreach { row => + val typeName = row.getString(1) + // For CHAR and VARCHAR, we also compare the max length + if (typeName.contains("CHAR")) { + val charMaxLength = row.getInt(2) + assert(expectedTypes(row.getString(0)) == s"$typeName($charMaxLength)") + } else { + assert(expectedTypes(row.getString(0)) == typeName) + } + } + } + + val data = Seq[Row](Row(1, "dave", "Boston")) + val schema = StructType( + StructField("id", IntegerType) :: + StructField("first#name", StringType) :: + StructField("city", StringType) :: Nil) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + + // out-of-order + val expected1 = Map("id" -> "BIGINT", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), id BIGINT, city CHAR(20)", expected1) + // partial schema + val expected2 = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), city CHAR(20)", expected2) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + // should still respect the original column names + val expected = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CLOB") + testUserSpecifiedColTypes(df, "`FiRsT#NaMe` VARCHAR(123)", expected) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val schema = StructType( + StructField("id", IntegerType) :: + StructField("First#Name", StringType) :: + StructField("city", StringType) :: Nil) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + val expected = Map("id" -> "INTEGER", "First#Name" -> "VARCHAR(123)", "city" -> "CLOB") + testUserSpecifiedColTypes(df, "`First#Name` VARCHAR(123)", expected) + } + } + + test("SPARK-10849: jdbc CreateTableColumnTypes option with invalid data type") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[ParseException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "name CLOB(2000)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("DataType clob(2000) is not supported.")) + } + + test("SPARK-10849: jdbc CreateTableColumnTypes option with invalid syntax") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[ParseException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("no viable alternative at input")) + } + + test("SPARK-10849: jdbc CreateTableColumnTypes duplicate columns") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "name CHAR(20), id int, NaMe VARCHAR(100)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains( + "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe")) + } + } + + test("SPARK-10849: jdbc CreateTableColumnTypes invalid columns") { + // schema2 has the column "id" and "name" + 
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "firstName CHAR(20), id int") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("createTableColumnTypes option column firstName not found in " + + "schema struct")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "id int, Name VARCHAR(100)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("createTableColumnTypes option column Name not found in " + + "schema struct")) + } + } } From 93581fbc18c01595918c565f6737aaa666116114 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 23 Mar 2017 17:57:31 -0700 Subject: [PATCH 114/512] Fix compilation of the Scala 2.10 master branch ## What changes were proposed in this pull request? Fixes break caused by: https://github.com/apache/spark/commit/746a558de2136f91f8fe77c6e51256017aa50913 ## How was this patch tested? Compiled with `build/sbt -Dscala2.10 sql/compile` locally Author: Burak Yavuz Closes #17403 from brkyvz/onceTrigger2.10. --- .../spark/sql/streaming/ProcessingTime.scala | 20 +++++++++---------- .../apache/spark/sql/streaming/Trigger.java | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index bdad8e4717be..9ba1fc01cbd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.CalendarInterval */ @Experimental @InterfaceStability.Evolving -@deprecated("use Trigger.ProcessingTimeTrigger(intervalMs)", "2.2.0") +@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") } @@ -64,7 +64,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { */ @Experimental @InterfaceStability.Evolving -@deprecated("use Trigger.ProcessingTimeTrigger(intervalMs)", "2.2.0") +@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { /** @@ -76,9 +76,9 @@ object ProcessingTime { * }}} * * @since 2.0.0 - * @deprecated use Trigger.ProcessingTimeTrigger(interval) + * @deprecated use Trigger.ProcessingTime(interval) */ - @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def apply(interval: String): ProcessingTime = { if (StringUtils.isBlank(interval)) { throw new IllegalArgumentException( @@ -108,9 +108,9 @@ object ProcessingTime { * }}} * * @since 2.0.0 - * @deprecated use Trigger.ProcessingTimeTrigger(interval) + * @deprecated use Trigger.ProcessingTime(interval) */ - @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def apply(interval: Duration): ProcessingTime = { new ProcessingTime(interval.toMillis) } @@ -124,9 +124,9 @@ object ProcessingTime { * }}} * * @since 2.0.0 - * @deprecated use Trigger.ProcessingTimeTrigger(interval) + * @deprecated use 
Trigger.ProcessingTime(interval) */ - @deprecated("use Trigger.ProcessingTimeTrigger(interval)", "2.2.0") + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def create(interval: String): ProcessingTime = { apply(interval) } @@ -141,9 +141,9 @@ object ProcessingTime { * }}} * * @since 2.0.0 - * @deprecated use Trigger.ProcessingTimeTrigger(interval) + * @deprecated use Trigger.ProcessingTime(interval, unit) */ - @deprecated("use Trigger.ProcessingTimeTrigger(interval, unit)", "2.2.0") + @deprecated("use Trigger.ProcessingTime(interval, unit)", "2.2.0") def create(interval: Long, unit: TimeUnit): ProcessingTime = { new ProcessingTime(unit.toMillis(interval)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java index a03a851f245f..3e3997fa9bfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java @@ -43,7 +43,7 @@ public class Trigger { * @since 2.2.0 */ public static Trigger ProcessingTime(long intervalMs) { - return ProcessingTime.apply(intervalMs); + return ProcessingTime.create(intervalMs, TimeUnit.MILLISECONDS); } /** From d27daa54bd341b29737a6352d9a1055151248ae7 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Thu, 23 Mar 2017 18:42:13 -0700 Subject: [PATCH 115/512] [SPARK-19636][ML] Feature parity for correlation statistics in MLlib ## What changes were proposed in this pull request? This patch adds the Dataframes-based support for the correlation statistics found in the `org.apache.spark.mllib.stat.correlation.Statistics`, following the design doc discussed in the JIRA ticket. The current implementation is a simple wrapper around the `spark.mllib` implementation. Future optimizations can be implemented at a later stage. ## How was this patch tested? ``` build/sbt "testOnly org.apache.spark.ml.stat.StatisticsSuite" ``` Author: Timothy Hunter Closes #17108 from thunterdb/19636. --- .../apache/spark/ml/util/TestingUtils.scala | 8 ++ .../apache/spark/ml/stat/Correlation.scala | 86 +++++++++++++++++++ .../spark/ml/stat/CorrelationSuite.scala | 77 +++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala index 2327917e2cad..30edd00fb53e 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -32,6 +32,10 @@ object TestingUtils { * the relative tolerance is meaningless, so the exception will be raised to warn users. */ private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + // Special case for NaNs + if (x.isNaN && y.isNaN) { + return true + } val absX = math.abs(x) val absY = math.abs(y) val diff = math.abs(x - y) @@ -49,6 +53,10 @@ object TestingUtils { * Private helper function for comparing two values using absolute tolerance. 
*/ private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + // Special case for NaNs + if (x.isNaN && y.isNaN) { + return true + } math.abs(x - y) < eps } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala new file mode 100644 index 000000000000..a7243ccbf28c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{SQLDataTypes, Vector} +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * API for correlation functions in MLlib, compatible with Dataframes and Datasets. + * + * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]] + * to spark.ml's Vector types. + */ +@Since("2.2.0") +@Experimental +object Correlation { + + /** + * :: Experimental :: + * Compute the correlation matrix for the input RDD of Vectors using the specified method. + * Methods currently supported: `pearson` (default), `spearman`. + * + * @param dataset A dataset or a dataframe + * @param column The name of the column of vectors for which the correlation coefficient needs + * to be computed. This must be a column of the dataset, and it must contain + * Vector objects. + * @param method String specifying the method to use for computing correlation. + * Supported: `pearson` (default), `spearman` + * @return A dataframe that contains the correlation matrix of the column of vectors. This + * dataframe contains a single row and a single column of name + * '$METHODNAME($COLUMN)'. + * @throws IllegalArgumentException if the column is not a valid column in the dataset, or if + * the content of this column is not of type Vector. + * + * Here is how to access the correlation coefficient: + * {{{ + * val data: Dataset[Vector] = ... + * val Row(coeff: Matrix) = Statistics.corr(data, "value").head + * // coeff now contains the Pearson correlation matrix. + * }}} + * + * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to + * avoid recomputing the common lineage. 
+ */ + @Since("2.2.0") + def corr(dataset: Dataset[_], column: String, method: String): DataFrame = { + val rdd = dataset.select(column).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } + val oldM = OldStatistics.corr(rdd, method) + val name = s"$method($column)" + val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false))) + dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema) + } + + /** + * Compute the Pearson correlation matrix for the input Dataset of Vectors. + */ + @Since("2.2.0") + def corr(dataset: Dataset[_], column: String): DataFrame = { + corr(dataset, column, "pearson") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala new file mode 100644 index 000000000000..7d935e651f22 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.stat + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + + +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + + val xData = Array(1.0, 0.0, -2.0) + val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) + val data = Seq( + Vectors.dense(1.0, 0.0, 0.0, -2.0), + Vectors.dense(4.0, 5.0, 0.0, 3.0), + Vectors.dense(6.0, 7.0, 0.0, 8.0), + Vectors.dense(9.0, 0.0, 0.0, 1.0) + ) + + private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + private def extract(df: DataFrame): BDM[Double] = { + val Array(Row(mat: Matrix)) = df.collect() + mat.asBreeze.toDenseMatrix + } + + + test("corr(X) default, pearson") { + val defaultMat = Correlation.corr(X, "features") + val pearsonMat = Correlation.corr(X, "features", "pearson") + // scalastyle:off + val expected = Matrices.fromBreeze(BDM( + (1.00000000, 0.05564149, Double.NaN, 0.4004714), + (0.05564149, 1.00000000, Double.NaN, 0.9135959), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.40047142, 0.91359586, Double.NaN, 1.0000000))) + // scalastyle:on + + assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4) + assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4) + } + + test("corr(X) spearman") { + val spearmanMat = Correlation.corr(X, "features", "spearman") + // scalastyle:off + val expected = Matrices.fromBreeze(BDM( + (1.0000000, 0.1054093, Double.NaN, 0.4000000), + (0.1054093, 1.0000000, Double.NaN, 0.9486833), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.4000000, 0.9486833, Double.NaN, 1.0000000))) + // scalastyle:on + assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4) + } + +} From bb823ca4b479a00030c4919c2d857d254b2a44d8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 24 Mar 2017 12:57:56 +0800 Subject: [PATCH 116/512] [SPARK-19959][SQL] Fix to throw NullPointerException in df[java.lang.Long].collect ## What changes were proposed in this pull request? This PR fixes `NullPointerException` in the generated code by Catalyst. When we run the following code, we get the following `NullPointerException`. This is because there is no null checks for `inputadapter_value` while `java.lang.Long inputadapter_value` at Line 30 may have `null`. This happen when a type of DataFrame is nullable primitive type such as `java.lang.Long` and the wholestage codegen is used. While the physical plan keeps `nullable=true` in `input[0, java.lang.Long, true].longValue`, `BoundReference.doGenCode` ignores `nullable=true`. Thus, nullcheck code will not be generated and `NullPointerException` will occur. This PR checks the nullability and correctly generates nullcheck if needed. ```java sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF.collect ``` ```java Caused by: java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(generated.java:37) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:393) ... 
``` Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow serializefromobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 012 */ /* 013 */ public GeneratedIterator(Object[] references) { /* 014 */ this.references = references; /* 015 */ } /* 016 */ /* 017 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 018 */ partitionIndex = index; /* 019 */ this.inputs = inputs; /* 020 */ inputadapter_input = inputs[0]; /* 021 */ serializefromobject_result = new UnsafeRow(1); /* 022 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 023 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 024 */ /* 025 */ } /* 026 */ /* 027 */ protected void processNext() throws java.io.IOException { /* 028 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 029 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 030 */ java.lang.Long inputadapter_value = (java.lang.Long)inputadapter_row.get(0, null); /* 031 */ /* 032 */ boolean serializefromobject_isNull = true; /* 033 */ long serializefromobject_value = -1L; /* 034 */ if (!false) { /* 035 */ serializefromobject_isNull = false; /* 036 */ if (!serializefromobject_isNull) { /* 037 */ serializefromobject_value = inputadapter_value.longValue(); /* 038 */ } /* 039 */ /* 040 */ } /* 041 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 042 */ /* 043 */ if (serializefromobject_isNull) { /* 044 */ serializefromobject_rowWriter.setNullAt(0); /* 045 */ } else { /* 046 */ serializefromobject_rowWriter.write(0, serializefromobject_value); /* 047 */ } /* 048 */ append(serializefromobject_result); /* 049 */ if (shouldStop()) return; /* 050 */ } /* 051 */ } /* 052 */ } ``` Generated code with this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow serializefromobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 012 */ /* 013 */ public GeneratedIterator(Object[] references) { /* 014 */ this.references = references; /* 015 */ } /* 016 */ /* 017 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 018 */ partitionIndex = index; /* 019 */ this.inputs = inputs; /* 020 */ inputadapter_input = inputs[0]; /* 021 */ serializefromobject_result = new UnsafeRow(1); /* 022 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 023 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 
024 */ /* 025 */ } /* 026 */ /* 027 */ protected void processNext() throws java.io.IOException { /* 028 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 029 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 030 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 031 */ java.lang.Long inputadapter_value = inputadapter_isNull ? null : ((java.lang.Long)inputadapter_row.get(0, null)); /* 032 */ /* 033 */ boolean serializefromobject_isNull = true; /* 034 */ long serializefromobject_value = -1L; /* 035 */ if (!inputadapter_isNull) { /* 036 */ serializefromobject_isNull = false; /* 037 */ if (!serializefromobject_isNull) { /* 038 */ serializefromobject_value = inputadapter_value.longValue(); /* 039 */ } /* 040 */ /* 041 */ } /* 042 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 043 */ /* 044 */ if (serializefromobject_isNull) { /* 045 */ serializefromobject_rowWriter.setNullAt(0); /* 046 */ } else { /* 047 */ serializefromobject_rowWriter.write(0, serializefromobject_value); /* 048 */ } /* 049 */ append(serializefromobject_result); /* 050 */ if (shouldStop()) return; /* 051 */ } /* 052 */ } /* 053 */ } ``` ## How was this patch tested? Added new test suites in `DataFrameSuites` Author: Kazuaki Ishizaki Closes #17302 from kiszk/SPARK-19959. --- .../spark/sql/catalyst/plans/logical/object.scala | 5 ++++- .../apache/spark/sql/DataFrameImplicitsSuite.scala | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 6225b3fa4299..bfb70c2ef4c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -41,7 +41,10 @@ object CatalystSerde { } def generateObjAttr[T : Encoder]: Attribute = { - AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() + val enc = encoderFor[T] + val dataType = enc.deserializer.dataType + val nullable = !enc.clsTag.runtimeClass.isPrimitive + AttributeReference("obj", dataType, nullable)() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 094efbaeadcd..63094d1b6122 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -51,4 +51,15 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } + + test("SPARK-19959: df[java.lang.Long].collect includes null throws NullPointerException") { + checkAnswer(sparkContext.parallelize(Seq[java.lang.Integer](0, null, 2), 1).toDF, + Seq(Row(0), Row(null), Row(2))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF, + Seq(Row(0L), Row(null), Row(2L))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Float](0.0F, null, 2.0F), 1).toDF, + Seq(Row(0.0F), Row(null), Row(2.0F))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Double](0.0D, null, 2.0D), 1).toDF, + Seq(Row(0.0D), Row(null), Row(2.0D))) + } } From 19596c28b6ef6e7abe0cfccfd2269c2fddf1fdee Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 23 Mar 2017 23:25:56 -0700 Subject: 
[PATCH 117/512] [SPARK-16929] Improve performance when check speculatable tasks. ## What changes were proposed in this pull request? 1. Use a MedianHeap to record durations of successful tasks. When check speculatable tasks, we can get the median duration with O(1) time complexity. 2. `checkSpeculatableTasks` will synchronize `TaskSchedulerImpl`. If `checkSpeculatableTasks` doesn't finish with 100ms, then the possibility exists for that thread to release and then immediately re-acquire the lock. Change `scheduleAtFixedRate` to be `scheduleWithFixedDelay` when call method of `checkSpeculatableTasks`. ## How was this patch tested? Added MedianHeapSuite. Author: jinxing Closes #16867 from jinxing64/SPARK-16929. --- .../spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 19 +++- .../spark/util/collection/MedianHeap.scala | 93 +++++++++++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 2 + .../util/collection/MedianHeapSuite.scala | 66 +++++++++++++ 5 files changed, 176 insertions(+), 6 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala create mode 100644 core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8257c70d672a..d6225a08739d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -174,7 +174,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - speculationScheduler.scheduleAtFixedRate(new Runnable { + speculationScheduler.scheduleWithFixedDelay(new Runnable { override def run(): Unit = Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index fd93a1f5c5d2..f4a21bca79aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -19,11 +19,10 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer -import java.util.Arrays import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.math.{max, min} +import scala.math.max import scala.util.control.NonFatal import org.apache.spark._ @@ -31,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} +import org.apache.spark.util.collection.MedianHeap /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -63,6 +63,8 @@ private[spark] class TaskSetManager( // Limit of bytes for total size of results (default is 1GB) val maxResultSize = Utils.getMaxResultSize(conf) + val speculationEnabled = conf.getBoolean("spark.speculation", false) + // Serializer for closures and tasks. 
val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -141,6 +143,11 @@ private[spark] class TaskSetManager( // Task index, start and finish time for each task attempt (indexed by task ID) private val taskInfos = new HashMap[Long, TaskInfo] + // Use a MedianHeap to record durations of successful tasks so we know when to launch + // speculative tasks. This is only used when speculation is enabled, to avoid the overhead + // of inserting into the heap when the heap won't be used. + val successfulTaskDurations = new MedianHeap() + // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = conf.getLong("spark.logging.exceptionPrintInterval", 10000) @@ -698,6 +705,9 @@ private[spark] class TaskSetManager( val info = taskInfos(tid) val index = info.index info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) + if (speculationEnabled) { + successfulTaskDurations.insert(info.duration) + } removeRunningTask(tid) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not @@ -919,11 +929,10 @@ private[spark] class TaskSetManager( var foundTasks = false val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { val time = clock.getTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1)) + var medianDuration = successfulTaskDurations.median val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation) // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. diff --git a/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala new file mode 100644 index 000000000000..6e57c3c5bee8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.PriorityQueue + +/** + * MedianHeap is designed to be used to quickly track the median of a group of numbers + * that may contain duplicates. Inserting a new number has O(log n) time complexity and + * determining the median has O(1) time complexity. + * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. 
The smallerHalf + * stores the smaller half of all numbers while the largerHalf stores the larger half. + * The sizes of two heaps need to be balanced each time when a new number is inserted so + * that their sizes will not be different by more than 1. Therefore each time when + * findMedian() is called we check if two heaps have the same size. If they do, we should + * return the average of the two top values of heaps. Otherwise we return the top of the + * heap which has one more element. + */ +private[spark] class MedianHeap(implicit val ord: Ordering[Double]) { + + /** + * Stores all the numbers less than the current median in a smallerHalf, + * i.e median is the maximum, at the root. + */ + private[this] var smallerHalf = PriorityQueue.empty[Double](ord) + + /** + * Stores all the numbers greater than the current median in a largerHalf, + * i.e median is the minimum, at the root. + */ + private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse) + + def isEmpty(): Boolean = { + smallerHalf.isEmpty && largerHalf.isEmpty + } + + def size(): Int = { + smallerHalf.size + largerHalf.size + } + + def insert(x: Double): Unit = { + // If both heaps are empty, we arbitrarily insert it into a heap, let's say, the largerHalf. + if (isEmpty) { + largerHalf.enqueue(x) + } else { + // If the number is larger than current median, it should be inserted into largerHalf, + // otherwise smallerHalf. + if (x > median) { + largerHalf.enqueue(x) + } else { + smallerHalf.enqueue(x) + } + } + rebalance() + } + + private[this] def rebalance(): Unit = { + if (largerHalf.size - smallerHalf.size > 1) { + smallerHalf.enqueue(largerHalf.dequeue()) + } + if (smallerHalf.size - largerHalf.size > 1) { + largerHalf.enqueue(smallerHalf.dequeue) + } + } + + def median: Double = { + if (isEmpty) { + throw new NoSuchElementException("MedianHeap is empty.") + } + if (largerHalf.size == smallerHalf.size) { + (largerHalf.head + smallerHalf.head) / 2.0 + } else if (largerHalf.size > smallerHalf.size) { + largerHalf.head + } else { + smallerHalf.head + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f36bcd8504b0..064af381a76d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -893,6 +893,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(4) // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => @@ -948,6 +949,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") sc.conf.set("spark.speculation.quantile", "0.6") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => diff --git 
a/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala new file mode 100644 index 000000000000..c2a3ee95f1c5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.NoSuchElementException + +import org.apache.spark.SparkFunSuite + +class MedianHeapSuite extends SparkFunSuite { + + test("If no numbers in MedianHeap, NoSuchElementException is thrown.") { + val medianHeap = new MedianHeap() + intercept[NoSuchElementException] { + medianHeap.median + } + } + + test("Median should be correct when size of MedianHeap is even") { + val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size() === 10) + assert(medianHeap.median === 4.5) + } + + test("Median should be correct when size of MedianHeap is odd") { + val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size() === 9) + assert(medianHeap.median === 4) + } + + test("Median should be correct though there are duplicated numbers inside.") { + val array = Array(0, 0, 1, 1, 2, 3, 4) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size === 7) + assert(medianHeap.median === 1) + } + + test("Median should be correct when input data is skewed.") { + val medianHeap = new MedianHeap() + (0 until 10).foreach(_ => medianHeap.insert(5)) + assert(medianHeap.median === 5) + (0 until 100).foreach(_ => medianHeap.insert(10)) + assert(medianHeap.median === 10) + (0 until 1000).foreach(_ => medianHeap.insert(0)) + assert(medianHeap.median === 0) + } +} From 8e558041aa0c41ba9fb2ce242daaf6d6ed4d85b7 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 23 Mar 2017 23:30:40 -0700 Subject: [PATCH 118/512] [SPARK-19820][CORE] Add interface to kill tasks w/ a reason This commit adds a killTaskAttempt method to SparkContext, to allow users to kill tasks so that they can be re-scheduled elsewhere. This also refactors the task kill path to allow specifying a reason for the task kill. The reason is propagated opaquely through events, and will show up in the UI automatically as `(N killed: $reason)` and `TaskKilled: $reason`. Without this change, there is no way to provide the user feedback through the UI. Currently used reasons are "stage cancelled", "another attempt succeeded", and "killed via SparkContext.killTask". The user can also specify a custom reason through `SparkContext.killTask`. 
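A minimal usage sketch, assuming a running `SparkContext` named `sc`; the task attempt id and the reason string below are placeholders (a real id would come from the UI or from `SparkListener.onTaskStart`):

```scala
// Placeholder values for illustration only: 42L stands in for a real task
// attempt id observed in the Spark UI or via SparkListener.onTaskStart.
val killed: Boolean = sc.killTaskAttempt(
  taskId = 42L,
  interruptThread = true,               // also interrupt the task thread on the executor
  reason = "straggling on a bad node")  // free-form reason, surfaced in the UI
// Returns whether a task with that id was found and a kill request was sent;
// the killed attempt is then rescheduled as usual.
```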
cc rxin In the stage overview UI the reasons are summarized: ![1](https://cloud.githubusercontent.com/assets/14922/23929209/a83b2862-08e1-11e7-8b3e-ae1967bbe2e5.png) Within the stage UI you can see individual task kill reasons: ![2](https://cloud.githubusercontent.com/assets/14922/23929200/9a798692-08e1-11e7-8697-72b27ad8a287.png) Existing tests, tried killing some stages in the UI and verified the messages are as expected. Author: Eric Liang Author: Eric Liang Closes #17166 from ericl/kill-reason. --- .../unsafe/sort/UnsafeInMemorySorter.java | 5 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 5 +- .../apache/spark/InterruptibleIterator.scala | 7 +-- .../scala/org/apache/spark/SparkContext.scala | 18 +++++++ .../scala/org/apache/spark/TaskContext.scala | 10 ++++ .../org/apache/spark/TaskContextImpl.scala | 21 ++++++-- .../org/apache/spark/TaskEndReason.scala | 4 +- .../apache/spark/TaskKilledException.scala | 4 +- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../CoarseGrainedExecutorBackend.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 52 ++++++++++--------- .../apache/spark/scheduler/DAGScheduler.scala | 11 +++- .../spark/scheduler/SchedulerBackend.scala | 15 +++++- .../org/apache/spark/scheduler/Task.scala | 21 ++++---- .../spark/scheduler/TaskScheduler.scala | 7 +++ .../spark/scheduler/TaskSchedulerImpl.scala | 16 +++++- .../spark/scheduler/TaskSetManager.scala | 10 +++- .../cluster/CoarseGrainedClusterMessage.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 10 ++-- .../local/LocalSchedulerBackend.scala | 11 ++-- .../scala/org/apache/spark/ui/UIUtils.scala | 7 ++- .../apache/spark/ui/jobs/AllJobsPage.scala | 4 +- .../apache/spark/ui/jobs/ExecutorTable.scala | 4 +- .../spark/ui/jobs/JobProgressListener.scala | 17 +++--- .../org/apache/spark/ui/jobs/StageTable.scala | 2 +- .../org/apache/spark/ui/jobs/UIData.scala | 6 +-- .../org/apache/spark/util/JsonProtocol.scala | 7 ++- .../org/apache/spark/SparkContextSuite.scala | 47 ++++++++++++++++- .../apache/spark/executor/ExecutorSuite.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 6 +++ .../ExternalClusterManagerSuite.scala | 2 + .../OutputCommitCoordinatorSuite.scala | 4 +- .../scheduler/SchedulerIntegrationSuite.scala | 3 +- .../spark/scheduler/TaskSetManagerSuite.scala | 17 +++--- .../org/apache/spark/ui/UIUtilsSuite.scala | 2 +- .../ui/jobs/JobProgressListenerSuite.scala | 5 +- .../apache/spark/util/JsonProtocolSuite.scala | 5 +- project/MimaExcludes.scala | 13 +++++ .../spark/executor/MesosExecutorBackend.scala | 3 +- .../MesosFineGrainedSchedulerBackend.scala | 3 +- .../execution/datasources/FileScanRDD.scala | 4 +- .../spark/streaming/ui/AllBatchesTable.scala | 2 +- .../apache/spark/streaming/ui/BatchPage.scala | 2 +- 43 files changed, 289 insertions(+), 115 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index f219c5605b64..c14c12664f5a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -23,7 +23,6 @@ import org.apache.avro.reflect.Nullable; import org.apache.spark.TaskContext; -import org.apache.spark.TaskKilledException; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -291,8 +290,8 @@ public void 
loadNext() { // to avoid performance overhead. This check is added here in `loadNext()` instead of in // `hasNext()` because it's technically possible for the caller to be relying on // `getNumRecords()` instead of `hasNext()` to know when to stop. - if (taskContext != null && taskContext.isInterrupted()) { - throw new TaskKilledException(); + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); } // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index b6323c624b7b..9521ab86a12d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -24,7 +24,6 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.TaskKilledException; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; @@ -102,8 +101,8 @@ public void loadNext() throws IOException { // to avoid performance overhead. This check is added here in `loadNext()` instead of in // `hasNext()` because it's technically possible for the caller to be relying on // `getNumRecords()` instead of `hasNext()` to know when to stop. - if (taskContext != null && taskContext.isInterrupted()) { - throw new TaskKilledException(); + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); } recordLength = din.readInt(); keyPrefix = din.readLong(); diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index 5c262bcbddf7..7f2c0068174b 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,11 +33,8 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. - if (context.isInterrupted) { - throw new TaskKilledException - } else { - delegate.hasNext - } + context.killTaskIfInterrupted() + delegate.hasNext } def next(): T = delegate.next() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e36a30c933d..0225fd605607 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2249,6 +2249,24 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelStage(stageId, None) } + /** + * Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI + * or through SparkListener.onTaskStart. + * + * @param taskId the task ID to kill. This id uniquely identifies the task attempt. + * @param interruptThread whether to interrupt the thread running the task. + * @param reason the reason for killing the task, which should be a short string. If a task + * is killed multiple times with different reasons, only one reason will be reported. 
+ * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt( + taskId: Long, + interruptThread: Boolean = true, + reason: String = "killed via SparkContext.killTaskAttempt"): Boolean = { + dagScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 5acfce17593b..0b87cd503d4f 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -184,6 +184,16 @@ abstract class TaskContext extends Serializable { @DeveloperApi def getMetricsSources(sourceName: String): Seq[Source] + /** + * If the task is interrupted, throws TaskKilledException with the reason for the interrupt. + */ + private[spark] def killTaskIfInterrupted(): Unit + + /** + * If the task is interrupted, the reason this task was killed, otherwise None. + */ + private[spark] def getKillReason(): Option[String] + /** * Returns the manager for this task's managed memory. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index f346cf8d6580..8cd1d1c96aa0 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -59,8 +59,8 @@ private[spark] class TaskContextImpl( /** List of callback functions to execute when the task fails. */ @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false + // If defined, the corresponding task has been killed and this option contains the reason. + @volatile private var reasonIfKilled: Option[String] = None // Whether the task has completed. private var completed: Boolean = false @@ -140,8 +140,19 @@ private[spark] class TaskContextImpl( } /** Marks the task for interruption, i.e. cancellation. */ - private[spark] def markInterrupted(): Unit = { - interrupted = true + private[spark] def markInterrupted(reason: String): Unit = { + reasonIfKilled = Some(reason) + } + + private[spark] override def killTaskIfInterrupted(): Unit = { + val reason = reasonIfKilled + if (reason.isDefined) { + throw new TaskKilledException(reason.get) + } + } + + private[spark] override def getKillReason(): Option[String] = { + reasonIfKilled } @GuardedBy("this") @@ -149,7 +160,7 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = false - override def isInterrupted(): Boolean = interrupted + override def isInterrupted(): Boolean = reasonIfKilled.isDefined override def getLocalProperty(key: String): String = localProperties.getProperty(key) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8c1b5f7bf0d9..a76283e33fa6 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. 
*/ @DeveloperApi -case object TaskKilled extends TaskFailedReason { - override def toErrorString: String = "TaskKilled (killed intentionally)" +case class TaskKilled(reason: String) extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala index ad487c4efb87..9dbf0d493be1 100644 --- a/core/src/main/scala/org/apache/spark/TaskKilledException.scala +++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala @@ -24,4 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * Exception thrown when a task is explicitly killed (i.e., task failure is expected). */ @DeveloperApi -class TaskKilledException extends RuntimeException +class TaskKilledException(val reason: String) extends RuntimeException { + def this() = this("unknown reason") +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 04ae97ed3ccb..b0dd2fc187ba 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -215,7 +215,7 @@ private[spark] class PythonRunner( case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) case e: Exception if env.isStopped => logDebug("Exception thrown after context is stopped", e) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b376ecd301ea..ba0096d87456 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -97,11 +97,11 @@ private[spark] class CoarseGrainedExecutorBackend( executor.launchTask(this, taskDesc) } - case KillTask(taskId, _, interruptThread) => + case KillTask(taskId, _, interruptThread, reason) => if (executor == null) { exitExecutor(1, "Received KillTask command but executor was null") } else { - executor.killTask(taskId, interruptThread) + executor.killTask(taskId, interruptThread, reason) } case StopExecutor => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 790c1ae94247..99b1608010dd 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -158,7 +158,7 @@ private[spark] class Executor( threadPool.execute(tr) } - def killTask(taskId: Long, interruptThread: Boolean): Unit = { + def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { val taskRunner = runningTasks.get(taskId) if (taskRunner != null) { if (taskReaperEnabled) { @@ -168,7 +168,8 @@ private[spark] class Executor( case Some(existingReaper) => interruptThread && !existingReaper.interruptThread } if (shouldCreateReaper) { - val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread) + val taskReaper = new TaskReaper( + taskRunner, interruptThread = interruptThread, reason = reason) taskReaperForTask(taskId) = taskReaper Some(taskReaper) } else { @@ -178,7 +179,7 @@ 
private[spark] class Executor( // Execute the TaskReaper from outside of the synchronized block. maybeNewTaskReaper.foreach(taskReaperPool.execute) } else { - taskRunner.kill(interruptThread = interruptThread) + taskRunner.kill(interruptThread = interruptThread, reason = reason) } } } @@ -189,8 +190,9 @@ private[spark] class Executor( * tasks instead of taking the JVM down. * @param interruptThread whether to interrupt the task thread */ - def killAllTasks(interruptThread: Boolean) : Unit = { - runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread)) + def killAllTasks(interruptThread: Boolean, reason: String) : Unit = { + runningTasks.keys().asScala.foreach(t => + killTask(t, interruptThread = interruptThread, reason = reason)) } def stop(): Unit = { @@ -217,8 +219,8 @@ private[spark] class Executor( val threadName = s"Executor task launch worker for task $taskId" private val taskName = taskDescription.name - /** Whether this task has been killed. */ - @volatile private var killed = false + /** If specified, this task has been killed and this option contains the reason. */ + @volatile private var reasonIfKilled: Option[String] = None @volatile private var threadId: Long = -1 @@ -239,13 +241,13 @@ private[spark] class Executor( */ @volatile var task: Task[Any] = _ - def kill(interruptThread: Boolean): Unit = { - logInfo(s"Executor is trying to kill $taskName (TID $taskId)") - killed = true + def kill(interruptThread: Boolean, reason: String): Unit = { + logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason") + reasonIfKilled = Some(reason) if (task != null) { synchronized { if (!finished) { - task.kill(interruptThread) + task.kill(interruptThread, reason) } } } @@ -296,12 +298,13 @@ private[spark] class Executor( // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. - if (killed) { + val killReason = reasonIfKilled + if (killReason.isDefined) { // Throw an exception rather than returning, because returning within a try{} block // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl // exception will be caught by the catch block, leading to an incorrect ExceptionFailure // for the task. - throw new TaskKilledException + throw new TaskKilledException(killReason.get) } logDebug("Task " + taskId + "'s epoch is " + task.epoch) @@ -358,9 +361,7 @@ private[spark] class Executor( } else 0L // If the task has been killed, let's fail it. 
- if (task.killed) { - throw new TaskKilledException - } + task.context.killTaskIfInterrupted() val resultSer = env.serializer.newInstance() val beforeSerialization = System.currentTimeMillis() @@ -426,15 +427,17 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case _: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId)") + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case _: InterruptedException if task.killed => - logInfo(s"Executor interrupted and killed $taskName (TID $taskId)") + case _: InterruptedException if task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskFailedReason @@ -512,7 +515,8 @@ private[spark] class Executor( */ private class TaskReaper( taskRunner: TaskRunner, - val interruptThread: Boolean) + val interruptThread: Boolean, + val reason: String) extends Runnable { private[this] val taskId: Long = taskRunner.taskId @@ -533,7 +537,7 @@ private[spark] class Executor( // Only attempt to kill the task once. If interruptThread = false then a second kill // attempt would be a no-op and if interruptThread = true then it may not be safe or // effective to interrupt multiple times: - taskRunner.kill(interruptThread = interruptThread) + taskRunner.kill(interruptThread = interruptThread, reason = reason) // Monitor the killed task until it exits. The synchronization logic here is complicated // because we don't want to synchronize on the taskRunner while possibly taking a thread // dump, but we also need to be careful to avoid races between checking whether the task diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d944f268755d..09717316833a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -738,6 +738,15 @@ class DAGScheduler( eventProcessLoop.post(StageCancelled(stageId, reason)) } + /** + * Kill a given task. It will be retried. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + taskScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + /** * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since * the last fetch failure. @@ -1353,7 +1362,7 @@ class DAGScheduler( case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | TaskKilled | UnknownReason => + case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. 
} diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 8801a761afae..22db3350abfa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int - def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = + /** + * Requests that an executor kills a running task. + * + * @param taskId Id of the task. + * @param executorId Id of the executor the task is running on. + * @param interruptThread Whether the executor should interrupt the task thread. + * @param reason The reason for the task kill. + */ + def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = throw new UnsupportedOperationException + def isReady(): Boolean = true /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 70213722aae4..46ef23f316a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -89,8 +89,8 @@ private[spark] abstract class Task[T]( TaskContext.setTaskContext(context) taskThread = Thread.currentThread() - if (_killed) { - kill(interruptThread = false) + if (_reasonIfKilled != null) { + kill(interruptThread = false, _reasonIfKilled) } new CallerContext( @@ -158,17 +158,17 @@ private[spark] abstract class Task[T]( // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ - // A flag to indicate whether the task is killed. This is used in case context is not yet - // initialized when kill() is invoked. - @volatile @transient private var _killed = false + // If non-null, this task has been killed and the reason is as specified. This is used in case + // context is not yet initialized when kill() is invoked. + @volatile @transient private var _reasonIfKilled: String = null protected var _executorDeserializeTime: Long = 0 protected var _executorDeserializeCpuTime: Long = 0 /** - * Whether the task has been killed. + * If defined, this task has been killed and this option contains the reason. */ - def killed: Boolean = _killed + def reasonIfKilled: Option[String] = Option(_reasonIfKilled) /** * Returns the amount of time spent deserializing the RDD and function to be run. @@ -201,10 +201,11 @@ private[spark] abstract class Task[T]( * be called multiple times. * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread. */ - def kill(interruptThread: Boolean) { - _killed = true + def kill(interruptThread: Boolean, reason: String) { + require(reason != null) + _reasonIfKilled = reason if (context != null) { - context.markInterrupted() + context.markInterrupted(reason) } if (interruptThread && taskThread != null) { taskThread.interrupt() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index cd13eebe74a9..3de7d1f7de22 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -54,6 +54,13 @@ private[spark] trait TaskScheduler { // Cancel a stage. 
def cancelTasks(stageId: Int, interruptThread: Boolean): Unit + /** + * Kills a task attempt. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index d6225a08739d..07aea773fa63 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -241,7 +241,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( // simply abort the stage. tsm.runningTasksSet.foreach { tid => val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + backend.killTask(tid, execId, interruptThread, reason = "stage cancelled") } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) @@ -249,6 +249,18 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + override def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + logInfo(s"Killing task $taskId: $reason") + val execId = taskIdToExecutorId.get(taskId) + if (execId.isDefined) { + backend.killTask(taskId, execId.get, interruptThread, reason) + true + } else { + logWarning(s"Could not kill task $taskId because no task with that ID was found.") + false + } + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -469,7 +481,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( taskState: TaskState, reason: TaskFailedReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { + if (!taskSetManager.isZombie && !taskSetManager.someAttemptSucceeded(tid)) { // Need to revive offers again now that the task set manager state has been updated to // reflect failed tasks that need to be re-run. backend.reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index f4a21bca79aa..a177aab5f95d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -101,6 +101,10 @@ private[spark] class TaskSetManager( override def runningTasks: Int = runningTasksSet.size + def someAttemptSucceeded(tid: Long): Boolean = { + successful(taskInfos(tid).index) + } + // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the // task set is aborted (for example, because it was killed). 
TaskSetManagers remain in the zombie @@ -722,7 +726,11 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - sched.backend.killTask(attemptInfo.taskId, attemptInfo.executorId, true) + sched.backend.killTask( + attemptInfo.taskId, + attemptInfo.executorId, + interruptThread = true, + reason = "another attempt succeeded") } if (!successful(index)) { tasksSuccessful += 1 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 2898cd7d17ca..6b49bd699a13 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -40,7 +40,7 @@ private[spark] object CoarseGrainedClusterMessages { // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage - case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) + case class KillTask(taskId: Long, executor: String, interruptThread: Boolean, reason: String) extends CoarseGrainedClusterMessage case class KillExecutorsOnHost(host: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7e2cfaccfc7b..4eedaaea6119 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -132,10 +132,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case ReviveOffers => makeOffers() - case KillTask(taskId, executorId, interruptThread) => + case KillTask(taskId, executorId, interruptThread, reason) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) + executorInfo.executorEndpoint.send( + KillTask(taskId, executorId, interruptThread, reason)) case None => // Ignoring the task kill since the executor is not registered. 
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") @@ -428,8 +429,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp driverEndpoint.send(ReviveOffers) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason)) } override def defaultParallelism(): Int = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 625f998cd460..35509bc2f85b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -34,7 +34,7 @@ private case class ReviveOffers() private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) -private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class KillTask(taskId: Long, interruptThread: Boolean, reason: String) private case class StopExecutor() @@ -70,8 +70,8 @@ private[spark] class LocalEndpoint( reviveOffers() } - case KillTask(taskId, interruptThread) => - executor.killTask(taskId, interruptThread) + case KillTask(taskId, interruptThread, reason) => + executor.killTask(taskId, interruptThread, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -143,8 +143,9 @@ private[spark] class LocalSchedulerBackend( override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localEndpoint.send(KillTask(taskId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + localEndpoint.send(KillTask(taskId, interruptThread, reason)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index d161843dd223..e53d6907bc40 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -342,7 +342,7 @@ private[spark] object UIUtils extends Logging { completed: Int, failed: Int, skipped: Int, - killed: Int, + reasonToNumKilled: Map[String, Int], total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) // started + completed can be > total when there are speculative tasks @@ -354,7 +354,10 @@ private[spark] object UIUtils extends Logging { {completed}/{total} { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } - { if (killed > 0) s"($killed killed)" } + { reasonToNumKilled.toSeq.sortBy(-_._2).map { + case (reason, count) => s"($count killed: $reason)" + } + }
      diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index d217f558045f..18be0870746e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -630,8 +630,8 @@ private[ui] class JobPagedTable( {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, - failed = job.numFailedTasks, skipped = job.numSkippedTasks, killed = job.numKilledTasks, - total = job.numTasks - job.numSkippedTasks)} + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index cd1b02addc78..52f41298a172 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.killedTasks} + {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.map(_._2).sum} {v.failedTasks} - {v.killedTasks} + {v.reasonToNumKilled.map(_._2).sum} {v.succeededTasks} {if (stageData.hasInput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index e87caff42643..1cf03e1541d1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -371,8 +371,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => execSummary.succeededTasks += 1 - case TaskKilled => - execSummary.killedTasks += 1 + case kill: TaskKilled => + execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( + kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => execSummary.failedTasks += 1 } @@ -385,9 +386,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 None - case TaskKilled => - stageData.numKilledTasks += 1 - Some(TaskKilled.toErrorString) + case kill: TaskKilled => + stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( + kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) + Some(kill.toErrorString) case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates stageData.numFailedTasks += 1 Some(e.toErrorString) @@ -422,8 +424,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => jobData.numCompletedTasks += 1 - case TaskKilled => - jobData.numKilledTasks += 1 + case kill: TaskKilled => + jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( + kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => jobData.numFailedTasks += 1 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index e1fa9043b6a1..f4caad0f5871 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -300,7 +300,7 @@ private[ui] class StagePagedTable( {UIUtils.makeProgressBar(started = stageData.numActiveTasks, completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, - skipped = 0, killed = stageData.numKilledTasks, total = info.numTasks)} + skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)} {data.inputReadWithUnit} {data.outputWriteWithUnit} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 073f7edfc2fe..ac1a74ad8029 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -32,7 +32,7 @@ private[spark] object UIData { var taskTime : Long = 0 var failedTasks : Int = 0 var succeededTasks : Int = 0 - var killedTasks : Int = 0 + var reasonToNumKilled : Map[String, Int] = Map.empty var inputBytes : Long = 0 var inputRecords : Long = 0 var outputBytes : Long = 0 @@ -64,7 +64,7 @@ private[spark] object UIData { var numCompletedTasks: Int = 0, var numSkippedTasks: Int = 0, var numFailedTasks: Int = 0, - var numKilledTasks: Int = 0, + var reasonToNumKilled: Map[String, Int] = Map.empty, /* Stages */ var numActiveStages: Int = 0, // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: @@ -78,7 +78,7 @@ private[spark] object UIData { var numCompleteTasks: Int = _ var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ - var numKilledTasks: Int = _ + var reasonToNumKilled: Map[String, Int] = Map.empty var executorRunTime: Long = _ var executorCpuTime: Long = _ diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 4b4d2d10cbf8..2cb88919c8c8 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -390,6 +390,8 @@ private[spark] object JsonProtocol { ("Executor ID" -> executorId) ~ ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) + case taskKilled: TaskKilled => + ("Kill Reason" -> taskKilled.reason) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -877,7 +879,10 @@ private[spark] object JsonProtocol { })) ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost - case `taskKilled` => TaskKilled + case `taskKilled` => + val killReason = Utils.jsonOption(json \ "Kill Reason") + .map(_.extract[String]).getOrElse("unknown reason") + TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. 
To provide backward compatibility diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index d08a162feda0..2c947556dfd3 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFor import org.scalatest.concurrent.Eventually import org.scalatest.Matchers._ -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.Utils @@ -540,6 +540,48 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + // Launches one task that will run forever. Once the SparkListener detects the task has + // started, kill and re-schedule it. The second run of the task will complete immediately. + // If this test times out, then the first version of the task wasn't killed successfully. + test("Killing tasks") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + SparkContextSuite.isTaskStarted = false + SparkContextSuite.taskKilled = false + SparkContextSuite.taskSucceeded = false + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.isTaskStarted) + } + if (!SparkContextSuite.taskKilled) { + SparkContextSuite.taskKilled = true + sc.killTaskAttempt(taskStart.taskInfo.taskId, true, "first attempt will hang") + } + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.taskInfo.attemptNumber == 1 && taskEnd.reason == Success) { + SparkContextSuite.taskSucceeded = true + } + } + } + sc.addSparkListener(listener) + eventually(timeout(20.seconds)) { + sc.parallelize(1 to 1).foreach { x => + // first attempt will hang + if (!SparkContextSuite.isTaskStarted) { + SparkContextSuite.isTaskStarted = true + Thread.sleep(9999999) + } + // second attempt succeeds immediately + } + } + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.taskSucceeded) + } + } + test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " + "open streams to help debugging") { val fs = new DebugFilesystem() @@ -555,11 +597,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(exc.getCause() != null) stream.close() } - } object SparkContextSuite { @volatile var cancelJob = false @volatile var cancelStage = false @volatile var isTaskStarted = false + @volatile var taskKilled = false + @volatile var taskSucceeded = false } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 8150fff2d018..f47e574b4fc4 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -110,14 +110,14 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } // we know the task will be started, but not yet deserialized, because of the latches we // use in mockExecutorBackend. 
- executor.killAllTasks(true) + executor.killAllTasks(true, "test") executorSuiteHelper.latch2.countDown() if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) { fail("executor did not send second status update in time") } // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` - assert(executorSuiteHelper.testFailedReason === TaskKilled) + assert(executorSuiteHelper.testFailedReason === TaskKilled("test")) assert(executorSuiteHelper.taskState === TaskState.KILLED) } finally { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a9389003d5db..a10941b579fe 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -126,6 +126,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { cancelledStages += stageId } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -552,6 +554,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 37c124a726be..ba56af8215cd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -79,6 +79,8 @@ private class DummyTaskScheduler extends TaskScheduler { override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = {} override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 38b9d40329d4..e51e6a0d3ff6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -176,13 +176,13 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, 
attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed assert( !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer assert( outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 398ac3d6202d..8103983c4392 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -410,7 +410,8 @@ private[spark] abstract class MockBackend( } } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { // We have to implement this b/c of SPARK-15385. // Its OK for this to be a no-op, because even if a backend does implement killTask, // it really can only be "best-effort" in any case, and the scheduler should be robust to that. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 064af381a76d..132caef0978f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -677,7 +677,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {} + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = {} }) // Keep track of the number of tasks that are resubmitted, @@ -935,7 +939,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Complete the speculative attempt for the running task manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt - verify(sched.backend).killTask(3, "exec2", true) + verify(sched.backend).killTask(3, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. 
@@ -1023,14 +1027,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask = originalTasks(speculativeTask.index) - verify(sched.backend).killTask(origTask.taskId, "exec2", true) + verify(sched.backend).killTask(origTask.taskId, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. assert(sched.endedTasks(3) === Success) // also because the scheduler is a mock, our manager isn't notified about the task killed event, // so we do that manually - manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled) + manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled("test")) // this task has "failed" 4 times, but one of them doesn't count, so keep running the stage assert(manager.tasksSuccessful === 4) assert(!manager.isZombie) @@ -1047,7 +1051,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask2 = originalTasks(speculativeTask2.index) - verify(sched.backend).killTask(origTask2.taskId, "exec2", true) + verify(sched.backend).killTask(origTask2.taskId, "exec2", true, "another attempt succeeded") assert(manager.tasksSuccessful === 5) assert(manager.isZombie) } @@ -1102,8 +1106,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, TaskCommitDenied(0, 2, 0)) - tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, - TaskKilled) + tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test")) // Make sure that the blacklist ignored all of the task failures above, since they aren't // the fault of the executor where the task was running. diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 6335d905c0fb..c770fd5da76f 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite { } test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { - val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div") + val generated = makeProgressBar(2, 3, 0, 0, Map.empty, 4).head.child.filter(_.label == "div") val expected = Seq(
      diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index e3127da9a6b2..93964a2d5674 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -274,8 +274,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with // Make sure killed tasks are accounted for correctly. listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, metrics)) - assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1) + SparkListenerTaskEnd( + task.stageId, 0, taskType, TaskKilled("test"), taskInfo, metrics)) + assert(listener.stageIdToData((task.stageId, 0)).reasonToNumKilled === Map("test" -> 1)) // Make sure we count success as success. listener.onTaskEnd( diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 9f76c74bce89..a64dbeae4729 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -164,7 +164,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(fetchMetadataFailed) testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) - testTaskEndReason(TaskKilled) + testTaskEndReason(TaskKilled("test")) testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure"))) testTaskEndReason(UnknownReason) @@ -676,7 +676,8 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(r1.fullStackTrace === r2.fullStackTrace) assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b)) case (TaskResultLost, TaskResultLost) => - case (TaskKilled, TaskKilled) => + case (r1: TaskKilled, r2: TaskKilled) => + assert(r1.reason == r2.reason) case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => assert(jobId1 === jobId2) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9925a8ba7266..8ce9367c9b44 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -66,6 +66,19 @@ object MimaExcludes { // [SPARK-17161] Removing Python-friendly constructors not needed ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), + // [SPARK-19820] Allow reason to be specified to task kill + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.TaskKilled$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.countTowardsTaskFailures"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.toErrorString"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.TaskKilled.toString"), + 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.killTaskIfInterrupted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getKillReason"), + // [SPARK-19876] Add one time trigger, and improve Trigger APIs ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime") diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index b25253978258..a086ec7ea2da 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -104,7 +104,8 @@ private[spark] class MesosExecutorBackend logError("Received KillTask but executor was null") } else { // TODO: Determine the 'interruptOnCancel' property set for the given job. - executor.killTask(t.getValue.toLong, interruptThread = false) + executor.killTask( + t.getValue.toLong, interruptThread = false, reason = "killed by mesos") } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index f198f8893b3d..735c879c63c5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -428,7 +428,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true)) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { schedulerDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index a89d172a911a..9df20731c71d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -101,9 +101,7 @@ class FileScanRDD( // Kill the task in case it has been marked as killed. This logic is from // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order // to avoid performance overhead. 
- if (context.isInterrupted()) { - throw new TaskKilledException - } + context.killTaskIfInterrupted() (currentIterator != null && currentIterator.hasNext) || nextIterator() } def next(): Object = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 1352ca1c4c95..70b4bb466c46 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -97,7 +97,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) completed = batch.numCompletedOutputOp, failed = batch.numFailedOutputOp, skipped = 0, - killed = 0, + reasonToNumKilled = Map.empty, total = batch.outputOperations.size) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 1a87fc790f91..f55af6a5cc35 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -146,7 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { completed = sparkJob.numCompletedTasks, failed = sparkJob.numFailedTasks, skipped = sparkJob.numSkippedTasks, - killed = sparkJob.numKilledTasks, + reasonToNumKilled = sparkJob.reasonToNumKilled, total = sparkJob.numTasks - sparkJob.numSkippedTasks) } From 344f38b04b271b5f3ec2748b34db4e52d54da1bc Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 24 Mar 2017 14:42:33 +0800 Subject: [PATCH 119/512] [SPARK-19970][SQL][FOLLOW-UP] Table owner should be USER instead of PRINCIPAL in kerberized clusters #17311 ### What changes were proposed in this pull request? This is a follow-up for the PR: https://github.com/apache/spark/pull/17311 - For safety, use `sessionState` to get the user name, instead of calling `SessionState.get()` in the function `toHiveTable`. - Passing `user names` instead of `conf` when calling `toHiveTable`. ### How was this patch tested? N/A Author: Xiao Li Closes #17405 from gatorsmile/user. --- .../sql/hive/client/HiveClientImpl.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 13edcd051768..56ccac32a8d8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -207,6 +207,8 @@ private[hive] class HiveClientImpl( /** Returns the configuration for the current session. */ def conf: HiveConf = state.getConf + private val userName = state.getAuthenticator.getUserName + override def getConf(key: String, defaultValue: String): String = { conf.get(key, defaultValue) } @@ -413,7 +415,7 @@ private[hive] class HiveClientImpl( createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( - locationUri = shim.getDataLocation(h).map(CatalogUtils.stringToURI(_)), + locationUri = shim.getDataLocation(h).map(CatalogUtils.stringToURI), // To avoid ClassNotFound exception, we try our best to not get the format class, but get // the class name directly. 
However, for non-native tables, there is no interface to get // the format class name, so we may still throw ClassNotFound in this case. @@ -441,7 +443,7 @@ private[hive] class HiveClientImpl( } override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState { - client.createTable(toHiveTable(table, Some(conf)), ignoreIfExists) + client.createTable(toHiveTable(table, Some(userName)), ignoreIfExists) } override def dropTable( @@ -453,7 +455,7 @@ private[hive] class HiveClientImpl( } override def alterTable(tableName: String, table: CatalogTable): Unit = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) // Do not use `table.qualifiedName` here because this may be a rename val qualifiedTableName = s"${table.database}.$tableName" shim.alterTable(client, qualifiedTableName, hiveTable) @@ -522,7 +524,7 @@ private[hive] class HiveClientImpl( newSpecs: Seq[TablePartitionSpec]): Unit = withHiveState { require(specs.size == newSpecs.size, "number of old and new partition specs differ") val catalogTable = getTable(db, table) - val hiveTable = toHiveTable(catalogTable, Some(conf)) + val hiveTable = toHiveTable(catalogTable, Some(userName)) specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => val hivePart = getPartitionOption(catalogTable, oldSpec) .map { p => toHivePartition(p.copy(spec = newSpec), hiveTable) } @@ -535,7 +537,7 @@ private[hive] class HiveClientImpl( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { - val hiveTable = toHiveTable(getTable(db, table), Some(conf)) + val hiveTable = toHiveTable(getTable(db, table), Some(userName)) shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) } @@ -563,7 +565,7 @@ private[hive] class HiveClientImpl( override def getPartitionOption( table: CatalogTable, spec: TablePartitionSpec): Option[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val hivePartition = client.getPartition(hiveTable, spec.asJava, false) Option(hivePartition).map(fromHivePartition) } @@ -575,7 +577,7 @@ private[hive] class HiveClientImpl( override def getPartitions( table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val parts = spec match { case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) case Some(s) => @@ -589,7 +591,7 @@ private[hive] class HiveClientImpl( override def getPartitionsByFilter( table: CatalogTable, predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val parts = shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) parts @@ -817,9 +819,7 @@ private[hive] object HiveClientImpl { /** * Converts the native table metadata representation format CatalogTable to Hive's Table. */ - def toHiveTable( - table: CatalogTable, - conf: Option[HiveConf] = None): HiveTable = { + def toHiveTable(table: CatalogTable, userName: Option[String] = None): HiveTable = { val hiveTable = new HiveTable(table.database, table.identifier.table) // For EXTERNAL_TABLE, we also need to set EXTERNAL field in the table properties. 
// Otherwise, Hive metastore will change the table to a MANAGED_TABLE. @@ -851,10 +851,10 @@ private[hive] object HiveClientImpl { hiveTable.setFields(schema.asJava) } hiveTable.setPartCols(partCols.asJava) - conf.foreach { _ => hiveTable.setOwner(SessionState.get().getAuthenticator().getUserName()) } + userName.foreach(hiveTable.setOwner) hiveTable.setCreateTime((table.createTime / 1000).toInt) hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) - table.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach { loc => + table.storage.locationUri.map(CatalogUtils.URIToString).foreach { loc => hiveTable.getTTable.getSd.setLocation(loc)} table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) From d9f4ce6943c16a7e29f98e57c33acbfc0379b54d Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 24 Mar 2017 08:01:15 -0700 Subject: [PATCH 120/512] [SPARK-15040][ML][PYSPARK] Add Imputer to PySpark Add Python wrapper for `Imputer` feature transformer. ## How was this patch tested? New doc tests and tweak to PySpark ML `tests.py` Author: Nick Pentreath Closes #17316 from MLnick/SPARK-15040-pyspark-imputer. --- .../org/apache/spark/ml/feature/Imputer.scala | 10 +- python/pyspark/ml/feature.py | 160 ++++++++++++++++++ python/pyspark/ml/tests.py | 10 ++ 3 files changed, 175 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index b1a802ee13fc..ec4c6ad75ee2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -93,12 +93,12 @@ private[feature] trait ImputerParams extends Params with HasInputCols { /** * :: Experimental :: * Imputation estimator for completing missing values, either using the mean or the median - * of the column in which the missing values are located. The input column should be of - * DoubleType or FloatType. Currently Imputer does not support categorical features yet + * of the columns in which the missing values are located. The input columns should be of + * DoubleType or FloatType. Currently Imputer does not support categorical features * (SPARK-15041) and possibly creates incorrect values for a categorical feature. * * Note that the mean/median value is computed after filtering out missing values. - * All Null values in the input column are treated as missing, and so are also imputed. For + * All Null values in the input columns are treated as missing, and so are also imputed. For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. */ @Experimental @@ -176,8 +176,8 @@ object Imputer extends DefaultParamsReadable[Imputer] { * :: Experimental :: * Model fitted by [[Imputer]]. * - * @param surrogateDF a DataFrame contains inputCols and their corresponding surrogates, which are - * used to replace the missing values in the input DataFrame. + * @param surrogateDF a DataFrame containing inputCols and their corresponding surrogates, + * which are used to replace the missing values in the input DataFrame. 
*/ @Experimental class ImputerModel private[ml]( diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92f8549e9cb9..8d25f5b3a771 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -36,6 +36,7 @@ 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', + 'Imputer', 'ImputerModel', 'IndexToString', 'MaxAbsScaler', 'MaxAbsScalerModel', 'MinHashLSH', 'MinHashLSHModel', @@ -870,6 +871,165 @@ def idf(self): return self._call_java("idf") +@inherit_doc +class Imputer(JavaEstimator, HasInputCols, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Imputation estimator for completing missing values, either using the mean or the median + of the columns in which the missing values are located. The input columns should be of + DoubleType or FloatType. Currently Imputer does not support categorical features and + possibly creates incorrect values for a categorical feature. + + Note that the mean/median value is computed after filtering out missing values. + All Null values in the input columns are treated as missing, and so are also imputed. For + computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a + relative error of `0.001`. + + >>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0), + ... (4.0, 4.0), (5.0, 5.0)], ["a", "b"]) + >>> imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]) + >>> model = imputer.fit(df) + >>> model.surrogateDF.show() + +---+---+ + | a| b| + +---+---+ + |3.0|4.0| + +---+---+ + ... + >>> model.transform(df).show() + +---+---+-----+-----+ + | a| b|out_a|out_b| + +---+---+-----+-----+ + |1.0|NaN| 1.0| 4.0| + |2.0|NaN| 2.0| 4.0| + |NaN|3.0| 3.0| 3.0| + ... + >>> imputer.setStrategy("median").setMissingValue(1.0).fit(df).transform(df).show() + +---+---+-----+-----+ + | a| b|out_a|out_b| + +---+---+-----+-----+ + |1.0|NaN| 4.0| NaN| + ... + >>> imputerPath = temp_path + "/imputer" + >>> imputer.save(imputerPath) + >>> loadedImputer = Imputer.load(imputerPath) + >>> loadedImputer.getStrategy() == imputer.getStrategy() + True + >>> loadedImputer.getMissingValue() + 1.0 + >>> modelPath = temp_path + "/imputer-model" + >>> model.save(modelPath) + >>> loadedModel = ImputerModel.load(modelPath) + >>> loadedModel.transform(df).head().out_a == model.transform(df).head().out_a + True + + .. versionadded:: 2.2.0 + """ + + outputCols = Param(Params._dummy(), "outputCols", + "output column names.", typeConverter=TypeConverters.toListString) + + strategy = Param(Params._dummy(), "strategy", + "strategy for imputation. If mean, then replace missing values using the mean " + "value of the feature. If median, then replace missing values using the " + "median value of the feature.", + typeConverter=TypeConverters.toString) + + missingValue = Param(Params._dummy(), "missingValue", + "The placeholder for the missing values. 
All occurrences of missingValue " + "will be imputed.", typeConverter=TypeConverters.toFloat) + + @keyword_only + def __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, + outputCols=None): + """ + __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, \ + outputCols=None): + """ + super(Imputer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Imputer", self.uid) + self._setDefault(strategy="mean", missingValue=float("nan")) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.2.0") + def setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, + outputCols=None): + """ + setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, \ + outputCols=None) + Sets params for this Imputer. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.2.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + @since("2.2.0") + def getOutputCols(self): + """ + Gets the value of :py:attr:`outputCols` or its default value. + """ + return self.getOrDefault(self.outputCols) + + @since("2.2.0") + def setStrategy(self, value): + """ + Sets the value of :py:attr:`strategy`. + """ + return self._set(strategy=value) + + @since("2.2.0") + def getStrategy(self): + """ + Gets the value of :py:attr:`strategy` or its default value. + """ + return self.getOrDefault(self.strategy) + + @since("2.2.0") + def setMissingValue(self, value): + """ + Sets the value of :py:attr:`missingValue`. + """ + return self._set(missingValue=value) + + @since("2.2.0") + def getMissingValue(self): + """ + Gets the value of :py:attr:`missingValue` or its default value. + """ + return self.getOrDefault(self.missingValue) + + def _create_model(self, java_model): + return ImputerModel(java_model) + + +class ImputerModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Model fitted by :py:class:`Imputer`. + + .. versionadded:: 2.2.0 + """ + + @property + @since("2.2.0") + def surrogateDF(self): + """ + Returns a DataFrame containing inputCols and their corresponding surrogates, + which are used to replace the missing values in the input DataFrame. 
+ """ + return self._call_java("surrogateDF") + + @inherit_doc class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index f052f5bb770c..cc559db58720 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1273,6 +1273,7 @@ class DefaultValuesTests(PySparkTestCase): """ def check_params(self, py_stage): + import pyspark.ml.feature if not hasattr(py_stage, "_to_java"): return java_stage = py_stage._to_java() @@ -1292,6 +1293,15 @@ def check_params(self, py_stage): _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param)) py_stage._clear(p) py_default = py_stage.getOrDefault(p) + if isinstance(py_stage, pyspark.ml.feature.Imputer) and p.name == "missingValue": + # SPARK-15040 - default value for Imputer param 'missingValue' is NaN, + # and NaN != NaN, so handle it specially here + import math + self.assertTrue(math.isnan(java_default) and math.isnan(py_default), + "Java default %s and python default %s are not both NaN for " + "param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + return self.assertEqual(java_default, py_default, "Java default %s != python default %s of param %s for Params %s" % (str(java_default), str(py_default), p.name, str(py_stage))) From 9299d071f95798e33b18c08d3c75bb26f88b266b Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 24 Mar 2017 09:56:05 -0700 Subject: [PATCH 121/512] [SQL][MINOR] Fix for typo in Analyzer ## What changes were proposed in this pull request? Fix for typo in Analyzer ## How was this patch tested? local build Author: Jacek Laskowski Closes #17409 from jaceklaskowski/analyzer-typo. --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 036ed060d9ef..1b3a53c6359e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2502,7 +2502,7 @@ object TimeWindowing extends Rule[LogicalPlan] { substitutedPlan.withNewChildren(expandedPlan :: Nil) } else if (windowExpressions.size > 1) { p.failAnalysis("Multiple time window expressions would result in a cartesian product " + - "of rows, therefore they are not currently not supported.") + "of rows, therefore they are currently not supported.") } else { p // Return unchanged. Analyzer will throw exception later } From 707e501832fa7adde0a884c528a7352983d83520 Mon Sep 17 00:00:00 2001 From: Adam Budde Date: Fri, 24 Mar 2017 12:40:29 -0700 Subject: [PATCH 122/512] [SPARK-19911][STREAMING] Add builder interface for Kinesis DStreams ## What changes were proposed in this pull request? - Add new KinesisDStream.scala containing KinesisDStream.Builder class - Add KinesisDStreamBuilderSuite test suite - Make KinesisInputDStream ctor args package private for testing - Add JavaKinesisDStreamBuilderSuite test suite - Add args to KinesisInputDStream and KinesisReceiver for optional service-specific auth (Kinesis, DynamoDB and CloudWatch) ## How was this patch tested? Added ```KinesisDStreamBuilderSuite``` to verify builder class works as expected Author: Adam Budde Closes #17250 from budde/KinesisStreamBuilder. 
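As a quick orientation before the diff, the sketch below shows how the new builder added in KinesisInputDStream.scala is intended to be driven end to end. It is a minimal, hypothetical example: it only chains the setters that appear in the hunks below; the `KinesisInputDStream.builder` factory and the terminal `build()` call (returning a DStream of raw `Array[Byte]` records) come from parts of the file not reproduced in this excerpt; and the stream name, application name, and endpoint are made-up values.

```scala
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisInputDStream

object KinesisBuilderSketch {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(
      new SparkConf().setAppName("KinesisBuilderSketch").setMaster("local[2]"),
      Seconds(10))

    // Every setter here maps to a Builder method introduced by this patch; anything left
    // unset falls back to the defaults documented in the builder (LATEST initial position,
    // the us-east-1 endpoint/region, MEMORY_AND_DISK_2 storage, batch-interval checkpointing).
    val kinesisStream = KinesisInputDStream.builder
      .streamingContext(ssc)
      .streamName("example-stream")            // hypothetical stream name
      .checkpointAppName("example-kcl-app")    // hypothetical KCL application name
      .endpointUrl("https://kinesis.us-west-2.amazonaws.com")
      .regionName("us-west-2")
      .initialPositionInStream(InitialPositionInStream.TRIM_HORIZON)
      .checkpointInterval(Seconds(10))
      .storageLevel(StorageLevel.MEMORY_AND_DISK_2)
      .build()                                 // assumed terminal call, not visible in this hunk

    // Records arrive as raw bytes; decode them before use.
    kinesisStream
      .map(bytes => new String(bytes, java.nio.charset.StandardCharsets.UTF_8))
      .print()

    ssc.start()
    ssc.awaitTermination()
  }
}
```

The motivation for a builder, rather than yet more `KinesisUtils.createStream` overloads, is that the optional settings (region, initial position, and the per-service Kinesis/DynamoDB/CloudWatch credentials this patch introduces) stay readable at the call site and new options can be added later without breaking existing callers.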
--- .../kinesis/KinesisBackedBlockRDD.scala | 6 +- .../kinesis/KinesisInputDStream.scala | 259 +++++++++++++++++- .../streaming/kinesis/KinesisReceiver.scala | 20 +- .../streaming/kinesis/KinesisUtils.scala | 43 +-- .../SerializableCredentialsProvider.scala | 85 ------ .../kinesis/SparkAWSCredentials.scala | 182 ++++++++++++ .../JavaKinesisInputDStreamBuilderSuite.java | 63 +++++ .../KinesisInputDStreamBuilderSuite.scala | 115 ++++++++ .../kinesis/KinesisReceiverSuite.scala | 23 -- .../kinesis/KinesisStreamSuite.scala | 2 +- .../SparkAWSCredentialsBuilderSuite.scala | 100 +++++++ 11 files changed, 749 insertions(+), 149 deletions(-) delete mode 100644 external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala create mode 100644 external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala create mode 100644 external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java create mode 100644 external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala create mode 100644 external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 0f1790bddcc3..f31ebf1ec8da 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -82,8 +82,8 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, - val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, - val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider + val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _, + val kinesisCreds: SparkAWSCredentials = DefaultCredentials ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -109,7 +109,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = kinesisCredsProvider.provider.getCredentials + val credentials = kinesisCreds.provider.getCredentials partition.seqNumberRanges.ranges.iterator.flatMap { range => new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, range, retryTimeoutMs).map(messageHandler) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index fbc6b99443ed..8970ad2bafda 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -22,24 +22,28 @@ import scala.reflect.ClassTag import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, 
StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.api.java.JavaStreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo private[kinesis] class KinesisInputDStream[T: ClassTag]( _ssc: StreamingContext, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointAppName: String, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: Record => T, - kinesisCredsProvider: SerializableCredentialsProvider + val streamName: String, + val endpointUrl: String, + val regionName: String, + val initialPositionInStream: InitialPositionInStream, + val checkpointAppName: String, + val checkpointInterval: Duration, + val _storageLevel: StorageLevel, + val messageHandler: Record => T, + val kinesisCreds: SparkAWSCredentials, + val dynamoDBCreds: Option[SparkAWSCredentials], + val cloudWatchCreds: Option[SparkAWSCredentials] ) extends ReceiverInputDStream[T](_ssc) { private[streaming] @@ -61,7 +65,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( isBlockIdValid = isBlockIdValid, retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, messageHandler = messageHandler, - kinesisCredsProvider = kinesisCredsProvider) + kinesisCreds = kinesisCreds) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + " it may not be possible to recover from failures") @@ -71,7 +75,238 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, messageHandler, - kinesisCredsProvider) + checkpointAppName, checkpointInterval, _storageLevel, messageHandler, + kinesisCreds, dynamoDBCreds, cloudWatchCreds) } } + +@InterfaceStability.Evolving +object KinesisInputDStream { + /** + * Builder for [[KinesisInputDStream]] instances. + * + * @since 2.2.0 + */ + @InterfaceStability.Evolving + class Builder { + // Required params + private var streamingContext: Option[StreamingContext] = None + private var streamName: Option[String] = None + private var checkpointAppName: Option[String] = None + + // Params with defaults + private var endpointUrl: Option[String] = None + private var regionName: Option[String] = None + private var initialPositionInStream: Option[InitialPositionInStream] = None + private var checkpointInterval: Option[Duration] = None + private var storageLevel: Option[StorageLevel] = None + private var kinesisCredsProvider: Option[SparkAWSCredentials] = None + private var dynamoDBCredsProvider: Option[SparkAWSCredentials] = None + private var cloudWatchCredsProvider: Option[SparkAWSCredentials] = None + + /** + * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a + * required parameter. + * + * @param ssc [[StreamingContext]] used to construct Kinesis DStreams + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamingContext(ssc: StreamingContext): Builder = { + streamingContext = Option(ssc) + this + } + + /** + * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a + * required parameter. 
+ * + * @param jssc [[JavaStreamingContext]] used to construct Kinesis DStreams + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamingContext(jssc: JavaStreamingContext): Builder = { + streamingContext = Option(jssc.ssc) + this + } + + /** + * Sets the name of the Kinesis stream that the DStream will read from. This is a required + * parameter. + * + * @param streamName Name of Kinesis stream that the DStream will read from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamName(streamName: String): Builder = { + this.streamName = Option(streamName) + this + } + + /** + * Sets the KCL application name to use when checkpointing state to DynamoDB. This is a + * required parameter. + * + * @param appName Value to use for the KCL app name (used when creating the DynamoDB checkpoint + * table and when writing metrics to CloudWatch) + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def checkpointAppName(appName: String): Builder = { + checkpointAppName = Option(appName) + this + } + + /** + * Sets the AWS Kinesis endpoint URL. Defaults to "https://kinesis.us-east-1.amazonaws.com" if + * no custom value is specified + * + * @param url Kinesis endpoint URL to use + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def endpointUrl(url: String): Builder = { + endpointUrl = Option(url) + this + } + + /** + * Sets the AWS region to construct clients for. Defaults to "us-east-1" if no custom value + * is specified. + * + * @param regionName Name of AWS region to use (e.g. "us-west-2") + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def regionName(regionName: String): Builder = { + this.regionName = Option(regionName) + this + } + + /** + * Sets the initial position data is read from in the Kinesis stream. Defaults to + * [[InitialPositionInStream.LATEST]] if no custom value is specified. + * + * @param initialPosition InitialPositionInStream value specifying where Spark Streaming + * will start reading records in the Kinesis stream from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def initialPositionInStream(initialPosition: InitialPositionInStream): Builder = { + initialPositionInStream = Option(initialPosition) + this + } + + /** + * Sets how often the KCL application state is checkpointed to DynamoDB. Defaults to the Spark + * Streaming batch interval if no custom value is specified. + * + * @param interval [[Duration]] specifying how often the KCL state should be checkpointed to + * DynamoDB. + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def checkpointInterval(interval: Duration): Builder = { + checkpointInterval = Option(interval) + this + } + + /** + * Sets the storage level of the blocks for the DStream created. Defaults to + * [[StorageLevel.MEMORY_AND_DISK_2]] if no custom value is specified. + * + * @param storageLevel [[StorageLevel]] to use for the DStream data blocks + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def storageLevel(storageLevel: StorageLevel): Builder = { + this.storageLevel = Option(storageLevel) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS Kinesis + * endpoint. Defaults to [[DefaultCredentialsProvider]] if no custom value is specified. 
+ * + * @param credentials [[SparkAWSCredentials]] to use for Kinesis authentication + */ + def kinesisCredentials(credentials: SparkAWSCredentials): Builder = { + kinesisCredsProvider = Option(credentials) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS DynamoDB + * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. + * + * @param credentials [[SparkAWSCredentials]] to use for DynamoDB authentication + */ + def dynamoDBCredentials(credentials: SparkAWSCredentials): Builder = { + dynamoDBCredsProvider = Option(credentials) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS CloudWatch + * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. + * + * @param credentials [[SparkAWSCredentials]] to use for CloudWatch authentication + */ + def cloudWatchCredentials(credentials: SparkAWSCredentials): Builder = { + cloudWatchCredsProvider = Option(credentials) + this + } + + /** + * Create a new instance of [[KinesisInputDStream]] with configured parameters and the provided + * message handler. + * + * @param handler Function converting [[Record]] instances read by the KCL to DStream type [[T]] + * @return Instance of [[KinesisInputDStream]] constructed with configured parameters + */ + def buildWithMessageHandler[T: ClassTag]( + handler: Record => T): KinesisInputDStream[T] = { + val ssc = getRequiredParam(streamingContext, "streamingContext") + new KinesisInputDStream( + ssc, + getRequiredParam(streamName, "streamName"), + endpointUrl.getOrElse(DEFAULT_KINESIS_ENDPOINT_URL), + regionName.getOrElse(DEFAULT_KINESIS_REGION_NAME), + initialPositionInStream.getOrElse(DEFAULT_INITIAL_POSITION_IN_STREAM), + getRequiredParam(checkpointAppName, "checkpointAppName"), + checkpointInterval.getOrElse(ssc.graph.batchDuration), + storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), + handler, + kinesisCredsProvider.getOrElse(DefaultCredentials), + dynamoDBCredsProvider, + cloudWatchCredsProvider) + } + + /** + * Create a new instance of [[KinesisInputDStream]] with configured parameters and using the + * default message handler, which returns [[Array[Byte]]]. + * + * @return Instance of [[KinesisInputDStream]] constructed with configured parameters + */ + def build(): KinesisInputDStream[Array[Byte]] = buildWithMessageHandler(defaultMessageHandler) + + private def getRequiredParam[T](param: Option[T], paramName: String): T = param.getOrElse { + throw new IllegalArgumentException(s"No value provided for required parameter $paramName") + } + } + + /** + * Creates a [[KinesisInputDStream.Builder]] for constructing [[KinesisInputDStream]] instances. 
+ * + * @since 2.2.0 + * + * @return [[KinesisInputDStream.Builder]] instance + */ + def builder: Builder = new Builder + + private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { + if (record == null) return null + val byteBuffer = record.getData() + val byteArray = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(byteArray) + byteArray + } + + private[kinesis] val DEFAULT_KINESIS_ENDPOINT_URL: String = + "https://kinesis.us-east-1.amazonaws.com" + private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1" + private[kinesis] val DEFAULT_INITIAL_POSITION_IN_STREAM: InitialPositionInStream = + InitialPositionInStream.LATEST + private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2 +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 320728f4bb22..1026d0fcb59b 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -70,9 +70,14 @@ import org.apache.spark.util.Utils * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to - * generate the AWSCredentialsProvider instance used for KCL - * authorization. + * @param kinesisCreds SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize Kinesis API calls. + * @param cloudWatchCreds Optional SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize CloudWatch API + * calls. Will use kinesisCreds if value is None. + * @param dynamoDBCreds Optional SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize DynamoDB API calls. + * Will use kinesisCreds if value is None. 
*/ private[kinesis] class KinesisReceiver[T]( val streamName: String, @@ -83,7 +88,9 @@ private[kinesis] class KinesisReceiver[T]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - kinesisCredsProvider: SerializableCredentialsProvider) + kinesisCreds: SparkAWSCredentials, + dynamoDBCreds: Option[SparkAWSCredentials], + cloudWatchCreds: Option[SparkAWSCredentials]) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -140,10 +147,13 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) + val kinesisProvider = kinesisCreds.provider val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( checkpointAppName, streamName, - kinesisCredsProvider.provider, + kinesisProvider, + dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), + cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) .withInitialPositionInStream(initialPositionInStream) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2d777982e760..1298463bfba1 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -58,6 +58,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -73,7 +74,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, DefaultCredentialsProvider) + cleanedHandler, DefaultCredentials, None, None) } } @@ -108,6 +109,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -123,12 +125,12 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentialsProvider( + val kinesisCredsProvider = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider) + cleanedHandler, kinesisCredsProvider, None, None) } } @@ -169,6 +171,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. 
*/ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -187,16 +190,16 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = STSCredentialsProvider( + val kinesisCredsProvider = STSCredentials( stsRoleArn = stsAssumeRoleArn, stsSessionName = stsSessionName, stsExternalId = Option(stsExternalId), - longLivedCredsProvider = BasicCredentialsProvider( + longLivedCreds = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey)) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider) + cleanedHandler, kinesisCredsProvider, None, None) } } @@ -227,6 +230,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( ssc: StreamingContext, kinesisAppName: String, @@ -240,7 +244,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, DefaultCredentialsProvider) + KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None) } } @@ -272,6 +276,7 @@ object KinesisUtils { * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing * is enabled. Make sure that your checkpoint directory is secure. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( ssc: StreamingContext, kinesisAppName: String, @@ -284,12 +289,12 @@ object KinesisUtils { awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentialsProvider( + val kinesisCredsProvider = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, kinesisCredsProvider) + KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None) } } @@ -323,6 +328,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -372,6 +378,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -431,6 +438,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -482,6 +490,7 @@ object KinesisUtils { * on the workers. 
See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -493,7 +502,8 @@ object KinesisUtils { storageLevel: StorageLevel ): JavaReceiverInputDStream[Array[Byte]] = { createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_)) + initialPositionInStream, checkpointInterval, storageLevel, + KinesisInputDStream.defaultMessageHandler(_)) } /** @@ -524,6 +534,7 @@ object KinesisUtils { * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing * is enabled. Make sure that your checkpoint directory is secure. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -537,7 +548,7 @@ object KinesisUtils { awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = { createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, checkpointInterval, storageLevel, - defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) + KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } private def validateRegion(regionName: String): String = { @@ -545,14 +556,6 @@ object KinesisUtils { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") } } - - private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { - if (record == null) return null - val byteBuffer = record.getData() - val byteArray = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(byteArray) - byteArray - } } /** @@ -597,7 +600,7 @@ private class KinesisUtilsPythonHelper { validateAwsCreds(awsAccessKeyId, awsSecretKey) KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, + KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId) } else { validateAwsCreds(awsAccessKeyId, awsSecretKey) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala deleted file mode 100644 index aa6fe12edf74..000000000000 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import scala.collection.JavaConverters._ - -import com.amazonaws.auth._ - -import org.apache.spark.internal.Logging - -/** - * Serializable interface providing a method executors can call to obtain an - * AWSCredentialsProvider instance for authenticating to AWS services. - */ -private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable { - /** - * Return an AWSCredentialProvider instance that can be used by the Kinesis Client - * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). - */ - def provider: AWSCredentialsProvider -} - -/** Returns DefaultAWSCredentialsProviderChain for authentication. */ -private[kinesis] final case object DefaultCredentialsProvider - extends SerializableCredentialsProvider { - - def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain -} - -/** - * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using - * DefaultAWSCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain - * instance with the provided arguments (e.g. if they are null). - */ -private[kinesis] final case class BasicCredentialsProvider( - awsAccessKeyId: String, - awsSecretKey: String) extends SerializableCredentialsProvider with Logging { - - def provider: AWSCredentialsProvider = try { - new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) - } catch { - case e: IllegalArgumentException => - logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + - "falling back to DefaultAWSCredentialsProviderChain.", e) - new DefaultAWSCredentialsProviderChain - } -} - -/** - * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM - * role in order to authenticate against resources in an external account. - */ -private[kinesis] final case class STSCredentialsProvider( - stsRoleArn: String, - stsSessionName: String, - stsExternalId: Option[String] = None, - longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider) - extends SerializableCredentialsProvider { - - def provider: AWSCredentialsProvider = { - val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) - .withLongLivedCredentialsProvider(longLivedCredsProvider.provider) - stsExternalId match { - case Some(stsExternalId) => - builder.withExternalId(stsExternalId) - .build() - case None => - builder.build() - } - } -} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala new file mode 100644 index 000000000000..9facfe8ff2b0 --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ + +import com.amazonaws.auth._ + +import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.internal.Logging + +/** + * Serializable interface providing a method executors can call to obtain an + * AWSCredentialsProvider instance for authenticating to AWS services. + */ +private[kinesis] sealed trait SparkAWSCredentials extends Serializable { + /** + * Return an AWSCredentialProvider instance that can be used by the Kinesis Client + * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). + */ + def provider: AWSCredentialsProvider +} + +/** Returns DefaultAWSCredentialsProviderChain for authentication. */ +private[kinesis] final case object DefaultCredentials extends SparkAWSCredentials { + + def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain +} + +/** + * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using + * DefaultCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain + * instance with the provided arguments (e.g. if they are null). + */ +private[kinesis] final case class BasicCredentials( + awsAccessKeyId: String, + awsSecretKey: String) extends SparkAWSCredentials with Logging { + + def provider: AWSCredentialsProvider = try { + new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) + } catch { + case e: IllegalArgumentException => + logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + + "falling back to DefaultCredentialsProviderChain.", e) + new DefaultAWSCredentialsProviderChain + } +} + +/** + * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM + * role in order to authenticate against resources in an external account. + */ +private[kinesis] final case class STSCredentials( + stsRoleArn: String, + stsSessionName: String, + stsExternalId: Option[String] = None, + longLivedCreds: SparkAWSCredentials = DefaultCredentials) + extends SparkAWSCredentials { + + def provider: AWSCredentialsProvider = { + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) + .withLongLivedCredentialsProvider(longLivedCreds.provider) + stsExternalId match { + case Some(stsExternalId) => + builder.withExternalId(stsExternalId) + .build() + case None => + builder.build() + } + } +} + +@InterfaceStability.Evolving +object SparkAWSCredentials { + /** + * Builder for [[SparkAWSCredentials]] instances. + * + * @since 2.2.0 + */ + @InterfaceStability.Evolving + class Builder { + private var basicCreds: Option[BasicCredentials] = None + private var stsCreds: Option[STSCredentials] = None + + // scalastyle:off + /** + * Use a basic AWS keypair for long-lived authorization. + * + * @note The given AWS keypair will be saved in DStream checkpoints if checkpointing is + * enabled. Make sure that your checkpoint directory is secure. 
Prefer using the + * [[http://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default default provider chain]] + * instead if possible. + * + * @param accessKeyId AWS access key ID + * @param secretKey AWS secret key + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + // scalastyle:on + def basicCredentials(accessKeyId: String, secretKey: String): Builder = { + basicCreds = Option(BasicCredentials( + awsAccessKeyId = accessKeyId, + awsSecretKey = secretKey)) + this + } + + /** + * Use STS to assume an IAM role for temporary session-based authentication. Will use configured + * long-lived credentials for authorizing to STS itself (either the default provider chain + * or a configured keypair). + * + * @param roleArn ARN of IAM role to assume via STS + * @param sessionName Name to use for the STS session + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + def stsCredentials(roleArn: String, sessionName: String): Builder = { + stsCreds = Option(STSCredentials(stsRoleArn = roleArn, stsSessionName = sessionName)) + this + } + + /** + * Use STS to assume an IAM role for temporary session-based authentication. Will use configured + * long-lived credentials for authorizing to STS itself (either the default provider chain + * or a configured keypair). STS will validate the provided external ID with the one defined + * in the trust policy of the IAM role to be assumed (if one is present). + * + * @param roleArn ARN of IAM role to assume via STS + * @param sessionName Name to use for the STS session + * @param externalId External ID to validate against assumed IAM role's trust policy + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + def stsCredentials(roleArn: String, sessionName: String, externalId: String): Builder = { + stsCreds = Option(STSCredentials( + stsRoleArn = roleArn, + stsSessionName = sessionName, + stsExternalId = Option(externalId))) + this + } + + /** + * Returns the appropriate instance of [[SparkAWSCredentials]] given the configured + * parameters. + * + * - The long-lived credentials will either be [[DefaultCredentials]] or [[BasicCredentials]] + * if they were provided. + * + * - If STS credentials were provided, the configured long-lived credentials will be added to + * them and the result will be returned. + * + * - The long-lived credentials will be returned otherwise. + * + * @return [[SparkAWSCredentials]] to use for configured parameters + */ + def build(): SparkAWSCredentials = + stsCreds.map(_.copy(longLivedCreds = longLivedCreds)).getOrElse(longLivedCreds) + + private def longLivedCreds: SparkAWSCredentials = basicCreds.getOrElse(DefaultCredentials) + } + + /** + * Creates a [[SparkAWSCredentials.Builder]] for constructing + * [[SparkAWSCredentials]] instances. + * + * @since 2.2.0 + * + * @return [[SparkAWSCredentials.Builder]] instance + */ + def builder: Builder = new Builder +} diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java new file mode 100644 index 000000000000..7205f6e27266 --- /dev/null +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis; + +import org.junit.Test; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.Seconds; +import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; + +public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext { + /** + * Basic test to ensure that the KinesisDStream.Builder interface is accessible from Java. + */ + @Test + public void testJavaKinesisDStreamBuilder() { + String streamName = "a-very-nice-stream-name"; + String endpointUrl = "https://kinesis.us-west-2.amazonaws.com"; + String region = "us-west-2"; + InitialPositionInStream initialPosition = InitialPositionInStream.TRIM_HORIZON; + String appName = "a-very-nice-kinesis-app"; + Duration checkpointInterval = Seconds.apply(30); + StorageLevel storageLevel = StorageLevel.MEMORY_ONLY(); + + KinesisInputDStream kinesisDStream = KinesisInputDStream.builder() + .streamingContext(ssc) + .streamName(streamName) + .endpointUrl(endpointUrl) + .regionName(region) + .initialPositionInStream(initialPosition) + .checkpointAppName(appName) + .checkpointInterval(checkpointInterval) + .storageLevel(storageLevel) + .build(); + assert(kinesisDStream.streamName() == streamName); + assert(kinesisDStream.endpointUrl() == endpointUrl); + assert(kinesisDStream.regionName() == region); + assert(kinesisDStream.initialPositionInStream() == initialPosition); + assert(kinesisDStream.checkpointAppName() == appName); + assert(kinesisDStream.checkpointInterval() == checkpointInterval); + assert(kinesisDStream._storageLevel() == storageLevel); + ssc.stop(); + } +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala new file mode 100644 index 000000000000..1c130654f3f9 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.lang.IllegalArgumentException + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, TestSuiteBase} + +class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterEach + with MockitoSugar { + import KinesisInputDStream._ + + private val ssc = new StreamingContext(conf, batchDuration) + private val streamName = "a-very-nice-kinesis-stream-name" + private val checkpointAppName = "a-very-nice-kcl-app-name" + private def baseBuilder = KinesisInputDStream.builder + private def builder = baseBuilder.streamingContext(ssc) + .streamName(streamName) + .checkpointAppName(checkpointAppName) + + override def afterAll(): Unit = { + ssc.stop() + } + + test("should raise an exception if the StreamingContext is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamName(streamName).checkpointAppName(checkpointAppName).build() + } + } + + test("should raise an exception if the stream name is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamingContext(ssc).checkpointAppName(checkpointAppName).build() + } + } + + test("should raise an exception if the checkpoint app name is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamingContext(ssc).streamName(streamName).build() + } + } + + test("should propagate required values to KinesisInputDStream") { + val dstream = builder.build() + assert(dstream.context == ssc) + assert(dstream.streamName == streamName) + assert(dstream.checkpointAppName == checkpointAppName) + } + + test("should propagate default values to KinesisInputDStream") { + val dstream = builder.build() + assert(dstream.endpointUrl == DEFAULT_KINESIS_ENDPOINT_URL) + assert(dstream.regionName == DEFAULT_KINESIS_REGION_NAME) + assert(dstream.initialPositionInStream == DEFAULT_INITIAL_POSITION_IN_STREAM) + assert(dstream.checkpointInterval == batchDuration) + assert(dstream._storageLevel == DEFAULT_STORAGE_LEVEL) + assert(dstream.kinesisCreds == DefaultCredentials) + assert(dstream.dynamoDBCreds == None) + assert(dstream.cloudWatchCreds == None) + } + + test("should propagate custom non-auth values to KinesisInputDStream") { + val customEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + val customRegion = "us-west-2" + val customInitialPosition = InitialPositionInStream.TRIM_HORIZON + val customAppName = "a-very-nice-kinesis-app" + val customCheckpointInterval = Seconds(30) + val customStorageLevel = StorageLevel.MEMORY_ONLY + val customKinesisCreds = mock[SparkAWSCredentials] + val customDynamoDBCreds = mock[SparkAWSCredentials] + val customCloudWatchCreds = mock[SparkAWSCredentials] + + val dstream = builder + .endpointUrl(customEndpointUrl) + .regionName(customRegion) + .initialPositionInStream(customInitialPosition) + .checkpointAppName(customAppName) + 
.checkpointInterval(customCheckpointInterval) + .storageLevel(customStorageLevel) + .kinesisCredentials(customKinesisCreds) + .dynamoDBCredentials(customDynamoDBCreds) + .cloudWatchCredentials(customCloudWatchCreds) + .build() + assert(dstream.endpointUrl == customEndpointUrl) + assert(dstream.regionName == customRegion) + assert(dstream.initialPositionInStream == customInitialPosition) + assert(dstream.checkpointAppName == customAppName) + assert(dstream.checkpointInterval == customCheckpointInterval) + assert(dstream._storageLevel == customStorageLevel) + assert(dstream.kinesisCreds == customKinesisCreds) + assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds)) + assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds)) + } +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index deb411d73e58..3b14c8471e20 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -31,7 +31,6 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.{Duration, TestSuiteBase} -import org.apache.spark.util.Utils /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -62,28 +61,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointerMock = mock[IRecordProcessorCheckpointer] } - test("check serializability of credential provider classes") { - Utils.deserialize[BasicCredentialsProvider]( - Utils.serialize(BasicCredentialsProvider( - awsAccessKeyId = "x", - awsSecretKey = "y"))) - - Utils.deserialize[STSCredentialsProvider]( - Utils.serialize(STSCredentialsProvider( - stsRoleArn = "fakeArn", - stsSessionName = "fakeSessionName", - stsExternalId = Some("fakeExternalId")))) - - Utils.deserialize[STSCredentialsProvider]( - Utils.serialize(STSCredentialsProvider( - stsRoleArn = "fakeArn", - stsSessionName = "fakeSessionName", - stsExternalId = Some("fakeExternalId"), - longLivedCredsProvider = BasicCredentialsProvider( - awsAccessKeyId = "x", - awsSecretKey = "y")))) - } - test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index afb55c84f81f..ed7e35805026 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -138,7 +138,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider( + assert(kinesisRDD.kinesisCreds === BasicCredentials( awsAccessKeyId = dummyAWSAccessKey, awsSecretKey = dummyAWSSecretKey)) assert(nonEmptyRDD.partitions.size === blockInfos.size) diff --git 
a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala new file mode 100644 index 000000000000..f579c2c3a679 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.util.Utils + +class SparkAWSCredentialsBuilderSuite extends TestSuiteBase { + private def builder = SparkAWSCredentials.builder + + private val basicCreds = BasicCredentials( + awsAccessKeyId = "a-very-nice-access-key", + awsSecretKey = "a-very-nice-secret-key") + + private val stsCreds = STSCredentials( + stsRoleArn = "a-very-nice-role-arn", + stsSessionName = "a-very-nice-secret-key", + stsExternalId = Option("a-very-nice-external-id"), + longLivedCreds = basicCreds) + + test("should build DefaultCredentials when given no params") { + assert(builder.build() == DefaultCredentials) + } + + test("should build BasicCredentials") { + assertResult(basicCreds) { + builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + } + + test("should build STSCredentials") { + // No external ID, default long-lived creds + assertResult(stsCreds.copy(stsExternalId = None, longLivedCreds = DefaultCredentials)) { + builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName) + .build() + } + // Default long-lived creds + assertResult(stsCreds.copy(longLivedCreds = DefaultCredentials)) { + builder.stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .build() + } + // No external ID, basic keypair for long-lived creds + assertResult(stsCreds.copy(stsExternalId = None)) { + builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName) + .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + // Basic keypair for long-lived creds + assertResult(stsCreds) { + builder.stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + // Order shouldn't matter + assertResult(stsCreds) { + builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .build() + } + } + + test("SparkAWSCredentials classes should be serializable") { + assertResult(basicCreds) { + Utils.deserialize[BasicCredentials](Utils.serialize(basicCreds)) + } + 
assertResult(stsCreds) { + Utils.deserialize[STSCredentials](Utils.serialize(stsCreds)) + } + // Will also test if DefaultCredentials can be serialized + val stsDefaultCreds = stsCreds.copy(longLivedCreds = DefaultCredentials) + assertResult(stsDefaultCreds) { + Utils.deserialize[STSCredentials](Utils.serialize(stsDefaultCreds)) + } + } +} From e8810b73c495b6d437dd3b9bb334762126b3c063 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 24 Mar 2017 20:32:42 +0000 Subject: [PATCH 123/512] [SPARK-17471][ML] Add compressed method to ML matrices ## What changes were proposed in this pull request? This patch adds a `compressed` method to ML `Matrix` class, which returns the minimal storage representation of the matrix - either sparse or dense. Because the space occupied by a sparse matrix is dependent upon its layout (i.e. column major or row major), this method must consider both cases. It may also be useful to force the layout to be column or row major beforehand, so an overload is added which takes in a `columnMajor: Boolean` parameter. The compressed implementation relies upon two new abstract methods `toDense(columnMajor: Boolean)` and `toSparse(columnMajor: Boolean)`, similar to the compressed method implemented in the `Vector` class. These methods also allow the layout of the resulting matrix to be specified via the `columnMajor` parameter. More detail on the new methods is given below. ## How was this patch tested? Added many new unit tests ## New methods (summary, not exhaustive list) **Matrix trait** - `private[ml] def toDenseMatrix(columnMajor: Boolean): DenseMatrix` (abstract) - converts the matrix (either sparse or dense) to dense format - `private[ml] def toSparseMatrix(columnMajor: Boolean): SparseMatrix` (abstract) - converts the matrix (either sparse or dense) to sparse format - `def toDense: DenseMatrix = toDense(true)` - converts the matrix (either sparse or dense) to dense format in column major layout - `def toSparse: SparseMatrix = toSparse(true)` - converts the matrix (either sparse or dense) to sparse format in column major layout - `def compressed: Matrix` - finds the minimum space representation of this matrix, considering both column and row major layouts, and converts it - `def compressed(columnMajor: Boolean): Matrix` - finds the minimum space representation of this matrix considering only column OR row major, and converts it **DenseMatrix class** - `private[ml] def toDenseMatrix(columnMajor: Boolean): DenseMatrix` - converts the dense matrix to a dense matrix, optionally changing the layout (data is NOT duplicated if the layouts are the same) - `private[ml] def toSparseMatrix(columnMajor: Boolean): SparseMatrix` - converts the dense matrix to sparse matrix, using the specified layout **SparseMatrix class** - `private[ml] def toDenseMatrix(columnMajor: Boolean): DenseMatrix` - converts the sparse matrix to a dense matrix, using the specified layout - `private[ml] def toSparseMatrix(columnMajors: Boolean): SparseMatrix` - converts the sparse matrix to sparse matrix. If the sparse matrix contains any explicit zeros, they are removed. If the layout requested does not match the current layout, data is copied to a new representation. If the layouts match and no explicit zeros exist, the current matrix is returned. Author: sethah Closes #15628 from sethah/matrix_compress. 
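As an illustration (not part of the change itself), the sketch below shows how the new compression helpers behave. The matrices are placeholder data, and the byte counts in the comments are worked out by hand from the `getDenseSize`/`getSparseSize` formulas introduced in this patch.

```scala
import org.apache.spark.ml.linalg.{Matrices, Matrix}

// 3 x 4 matrix with only two non-zero entries (values given in column-major order).
val mostlyZeros: Matrix = Matrices.dense(3, 4,
  Array(1.0, 0.0, 0.0,   // column 0
        0.0, 0.0, 0.0,   // column 1
        0.0, 0.0, 0.0,   // column 2
        0.0, 0.0, 2.0))  // column 3

// Storage estimates from the patch's size helpers:
//   dense             : 8 * 12 values + 12 + 9          = 117 bytes
//   sparse col major  : 12 * 2 nnz + 4 * (4 + 1) + 45   =  89 bytes
//   sparse row major  : 12 * 2 nnz + 4 * (3 + 1) + 45   =  85 bytes
// so `compressed` returns a sparse, row-major matrix for this input.
val compact = mostlyZeros.compressed

// A fully populated matrix is smaller as a plain dense array, so it stays dense
// and keeps its current layout.
val small: Matrix = Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))
val stillDense = small.compressed

// The layout can also be forced explicitly:
val csc = mostlyZeros.toSparseColMajor             // sparse, column major
val csr = mostlyZeros.toSparseRowMajor             // sparse, row major
val colMajorOnly = mostlyZeros.compressedColMajor  // best of dense vs. sparse column major
```

Because the length of a sparse matrix's pointer array depends on whether it is stored column or row major, `compressed` evaluates both sparse layouts against the dense size and keeps the current layout whenever dense wins.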
--- .../org/apache/spark/ml/linalg/Matrices.scala | 274 ++++++++++-- .../spark/ml/linalg/MatricesSuite.scala | 420 +++++++++++++++++- .../apache/spark/ml/linalg/VectorsSuite.scala | 5 + project/MimaExcludes.scala | 20 +- 4 files changed, 673 insertions(+), 46 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala index d9ffdeb797fb..07f3bc27280b 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala @@ -44,6 +44,12 @@ sealed trait Matrix extends Serializable { @Since("2.0.0") val isTransposed: Boolean = false + /** Indicates whether the values backing this matrix are arranged in column major order. */ + private[ml] def isColMajor: Boolean = !isTransposed + + /** Indicates whether the values backing this matrix are arranged in row major order. */ + private[ml] def isRowMajor: Boolean = isTransposed + /** Converts to a dense array in column major. */ @Since("2.0.0") def toArray: Array[Double] = { @@ -148,7 +154,8 @@ sealed trait Matrix extends Serializable { * and column indices respectively with the type `Int`, and the final parameter is the * corresponding value in the matrix with type `Double`. */ - private[spark] def foreachActive(f: (Int, Int, Double) => Unit) + @Since("2.2.0") + def foreachActive(f: (Int, Int, Double) => Unit): Unit /** * Find the number of non-zero active values. @@ -161,6 +168,116 @@ sealed trait Matrix extends Serializable { */ @Since("2.0.0") def numActives: Int + + /** + * Converts this matrix to a sparse matrix. + * + * @param colMajor Whether the values of the resulting sparse matrix should be in column major + * or row major order. If `false`, resulting matrix will be row major. + */ + private[ml] def toSparseMatrix(colMajor: Boolean): SparseMatrix + + /** + * Converts this matrix to a sparse matrix in column major order. + */ + @Since("2.2.0") + def toSparseColMajor: SparseMatrix = toSparseMatrix(colMajor = true) + + /** + * Converts this matrix to a sparse matrix in row major order. + */ + @Since("2.2.0") + def toSparseRowMajor: SparseMatrix = toSparseMatrix(colMajor = false) + + /** + * Converts this matrix to a sparse matrix while maintaining the layout of the current matrix. + */ + @Since("2.2.0") + def toSparse: SparseMatrix = toSparseMatrix(colMajor = isColMajor) + + /** + * Converts this matrix to a dense matrix. + * + * @param colMajor Whether the values of the resulting dense matrix should be in column major + * or row major order. If `false`, resulting matrix will be row major. + */ + private[ml] def toDenseMatrix(colMajor: Boolean): DenseMatrix + + /** + * Converts this matrix to a dense matrix while maintaining the layout of the current matrix. + */ + @Since("2.2.0") + def toDense: DenseMatrix = toDenseMatrix(colMajor = isColMajor) + + /** + * Converts this matrix to a dense matrix in row major order. + */ + @Since("2.2.0") + def toDenseRowMajor: DenseMatrix = toDenseMatrix(colMajor = false) + + /** + * Converts this matrix to a dense matrix in column major order. + */ + @Since("2.2.0") + def toDenseColMajor: DenseMatrix = toDenseMatrix(colMajor = true) + + /** + * Returns a matrix in dense or sparse column major format, whichever uses less storage. 
+ */ + @Since("2.2.0") + def compressedColMajor: Matrix = { + if (getDenseSizeInBytes <= getSparseSizeInBytes(colMajor = true)) { + this.toDenseColMajor + } else { + this.toSparseColMajor + } + } + + /** + * Returns a matrix in dense or sparse row major format, whichever uses less storage. + */ + @Since("2.2.0") + def compressedRowMajor: Matrix = { + if (getDenseSizeInBytes <= getSparseSizeInBytes(colMajor = false)) { + this.toDenseRowMajor + } else { + this.toSparseRowMajor + } + } + + /** + * Returns a matrix in dense column major, dense row major, sparse row major, or sparse column + * major format, whichever uses less storage. When dense representation is optimal, it maintains + * the current layout order. + */ + @Since("2.2.0") + def compressed: Matrix = { + val cscSize = getSparseSizeInBytes(colMajor = true) + val csrSize = getSparseSizeInBytes(colMajor = false) + if (getDenseSizeInBytes <= math.min(cscSize, csrSize)) { + // dense matrix size is the same for column major and row major, so maintain current layout + this.toDense + } else if (cscSize <= csrSize) { + this.toSparseColMajor + } else { + this.toSparseRowMajor + } + } + + /** Gets the size of the dense representation of this `Matrix`. */ + private[ml] def getDenseSizeInBytes: Long = { + Matrices.getDenseSize(numCols, numRows) + } + + /** Gets the size of the minimal sparse representation of this `Matrix`. */ + private[ml] def getSparseSizeInBytes(colMajor: Boolean): Long = { + val nnz = numNonzeros + val numPtrs = if (colMajor) numCols + 1L else numRows + 1L + Matrices.getSparseSize(nnz, numPtrs) + } + + /** Gets the current size in bytes of this `Matrix`. Useful for testing */ + private[ml] def getSizeInBytes: Long } /** @@ -258,7 +375,7 @@ class DenseMatrix @Since("2.0.0") ( override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) - private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { if (!isTransposed) { // outer loop over columns var j = 0 @@ -291,31 +408,49 @@ class DenseMatrix @Since("2.0.0") ( override def numActives: Int = values.length /** - * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed - * set to false. + * Generate a `SparseMatrix` from the given `DenseMatrix`. + * + * @param colMajor Whether the resulting `SparseMatrix` values will be in column major order. 
*/ - @Since("2.0.0") - def toSparse: SparseMatrix = { - val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble - val colPtrs: Array[Int] = new Array[Int](numCols + 1) - val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt - var nnz = 0 - var j = 0 - while (j < numCols) { - var i = 0 - while (i < numRows) { - val v = values(index(i, j)) - if (v != 0.0) { - rowIndices += i - spVals += v - nnz += 1 + private[ml] override def toSparseMatrix(colMajor: Boolean): SparseMatrix = { + if (!colMajor) this.transpose.toSparseColMajor.transpose + else { + val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble + val colPtrs: Array[Int] = new Array[Int](numCols + 1) + val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt + var nnz = 0 + var j = 0 + while (j < numCols) { + var i = 0 + while (i < numRows) { + val v = values(index(i, j)) + if (v != 0.0) { + rowIndices += i + spVals += v + nnz += 1 + } + i += 1 } - i += 1 + j += 1 + colPtrs(j) = nnz } - j += 1 - colPtrs(j) = nnz + new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) + } + } + + /** + * Generate a `DenseMatrix` from this `DenseMatrix`. + * + * @param colMajor Whether the resulting `DenseMatrix` values will be in column major order. + */ + private[ml] override def toDenseMatrix(colMajor: Boolean): DenseMatrix = { + if (isRowMajor && colMajor) { + new DenseMatrix(numRows, numCols, this.toArray, isTransposed = false) + } else if (isColMajor && !colMajor) { + new DenseMatrix(numRows, numCols, this.transpose.toArray, isTransposed = true) + } else { + this } - new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) } override def colIter: Iterator[Vector] = { @@ -331,6 +466,8 @@ class DenseMatrix @Since("2.0.0") ( } } } + + private[ml] def getSizeInBytes: Long = Matrices.getDenseSize(numCols, numRows) } /** @@ -560,7 +697,7 @@ class SparseMatrix @Since("2.0.0") ( override def transpose: SparseMatrix = new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) - private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { if (!isTransposed) { var j = 0 while (j < numCols) { @@ -587,18 +724,67 @@ class SparseMatrix @Since("2.0.0") ( } } + override def numNonzeros: Int = values.count(_ != 0) + + override def numActives: Int = values.length + /** - * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed - * set to false. + * Generate a `SparseMatrix` from this `SparseMatrix`, removing explicit zero values if they + * exist. + * + * @param colMajor Whether or not the resulting `SparseMatrix` values are in column major + * order. 
*/ - @Since("2.0.0") - def toDense: DenseMatrix = { - new DenseMatrix(numRows, numCols, toArray) + private[ml] override def toSparseMatrix(colMajor: Boolean): SparseMatrix = { + if (isColMajor && !colMajor) { + // it is col major and we want row major, use breeze to remove explicit zeros + val breezeTransposed = asBreeze.asInstanceOf[BSM[Double]].t + Matrices.fromBreeze(breezeTransposed).transpose.asInstanceOf[SparseMatrix] + } else if (isRowMajor && colMajor) { + // it is row major and we want col major, use breeze to remove explicit zeros + val breezeTransposed = asBreeze.asInstanceOf[BSM[Double]] + Matrices.fromBreeze(breezeTransposed).asInstanceOf[SparseMatrix] + } else { + val nnz = numNonzeros + if (nnz != numActives) { + // remove explicit zeros + val rr = new Array[Int](nnz) + val vv = new Array[Double](nnz) + val numPtrs = if (isRowMajor) numRows else numCols + val cc = new Array[Int](numPtrs + 1) + var nzIdx = 0 + var j = 0 + while (j < numPtrs) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + cc(j) = nzIdx + while (idx < idxEnd) { + if (values(idx) != 0.0) { + vv(nzIdx) = values(idx) + rr(nzIdx) = rowIndices(idx) + nzIdx += 1 + } + idx += 1 + } + j += 1 + } + cc(j) = nnz + new SparseMatrix(numRows, numCols, cc, rr, vv, isTransposed = isTransposed) + } else { + this + } + } } - override def numNonzeros: Int = values.count(_ != 0) - - override def numActives: Int = values.length + /** + * Generate a `DenseMatrix` from the given `SparseMatrix`. + * + * @param colMajor Whether the resulting `DenseMatrix` values are in column major order. + */ + private[ml] override def toDenseMatrix(colMajor: Boolean): DenseMatrix = { + if (colMajor) new DenseMatrix(numRows, numCols, this.toArray) + else new DenseMatrix(numRows, numCols, this.transpose.toArray, isTransposed = true) + } override def colIter: Iterator[Vector] = { if (isTransposed) { @@ -631,6 +817,8 @@ class SparseMatrix @Since("2.0.0") ( } } } + + private[ml] def getSizeInBytes: Long = Matrices.getSparseSize(numActives, colPtrs.length) } /** @@ -1079,4 +1267,26 @@ object Matrices { SparseMatrix.fromCOO(numRows, numCols, entries) } } + + private[ml] def getSparseSize(numActives: Long, numPtrs: Long): Long = { + /* + Sparse matrices store two int arrays, one double array, two ints, and one boolean: + 8 * values.length + 4 * rowIndices.length + 4 * colPtrs.length + arrayHeader * 3 + 2 * 4 + 1 + */ + val doubleBytes = java.lang.Double.BYTES + val intBytes = java.lang.Integer.BYTES + val arrayHeader = 12L + doubleBytes * numActives + intBytes * numActives + intBytes * numPtrs + arrayHeader * 3L + 9L + } + + private[ml] def getDenseSize(numCols: Long, numRows: Long): Long = { + /* + Dense matrices store one double array, two ints, and one boolean: + 8 * values.length + arrayHeader + 2 * 4 + 1 + */ + val doubleBytes = java.lang.Double.BYTES + val arrayHeader = 12L + doubleBytes * numCols * numRows + arrayHeader + 9L + } + } diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala index 9c0aa7393847..9f8202086817 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala @@ -160,22 +160,416 @@ class MatricesSuite extends SparkMLFunSuite { assert(sparseMat.values(2) === 10.0) } - test("toSparse, toDense") { - val m = 3 - val n = 2 - val values = Array(1.0, 2.0, 4.0, 5.0) - val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) 
- val colPtrs = Array(0, 2, 4) - val rowIndices = Array(0, 1, 1, 2) + test("dense to dense") { + /* + dm1 = 4.0 2.0 -8.0 + -1.0 7.0 4.0 + + dm2 = 5.0 -9.0 4.0 + 1.0 -3.0 -8.0 + */ + val dm1 = new DenseMatrix(2, 3, Array(4.0, -1.0, 2.0, 7.0, -8.0, 4.0)) + val dm2 = new DenseMatrix(2, 3, Array(5.0, -9.0, 4.0, 1.0, -3.0, -8.0), isTransposed = true) + + val dm8 = dm1.toDenseColMajor + assert(dm8 === dm1) + assert(dm8.isColMajor) + assert(dm8.values.equals(dm1.values)) + + val dm5 = dm2.toDenseColMajor + assert(dm5 === dm2) + assert(dm5.isColMajor) + assert(dm5.values === Array(5.0, 1.0, -9.0, -3.0, 4.0, -8.0)) + + val dm4 = dm1.toDenseRowMajor + assert(dm4 === dm1) + assert(dm4.isRowMajor) + assert(dm4.values === Array(4.0, 2.0, -8.0, -1.0, 7.0, 4.0)) + + val dm6 = dm2.toDenseRowMajor + assert(dm6 === dm2) + assert(dm6.isRowMajor) + assert(dm6.values.equals(dm2.values)) + + val dm3 = dm1.toDense + assert(dm3 === dm1) + assert(dm3.isColMajor) + assert(dm3.values.equals(dm1.values)) + + val dm9 = dm2.toDense + assert(dm9 === dm2) + assert(dm9.isRowMajor) + assert(dm9.values.equals(dm2.values)) + } - val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) - val deMat1 = new DenseMatrix(m, n, allValues) + test("dense to sparse") { + /* + dm1 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + + dm2 = 0.0 4.0 5.0 + 0.0 2.0 0.0 - val spMat2 = deMat1.toSparse - val deMat2 = spMat1.toDense + dm3 = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val dm1 = new DenseMatrix(2, 3, Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + val dm2 = new DenseMatrix(2, 3, Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0), isTransposed = true) + val dm3 = new DenseMatrix(2, 3, Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) + + val sm1 = dm1.toSparseColMajor + assert(sm1 === dm1) + assert(sm1.isColMajor) + assert(sm1.values === Array(4.0, 2.0, 5.0)) + + val sm3 = dm2.toSparseColMajor + assert(sm3 === dm2) + assert(sm3.isColMajor) + assert(sm3.values === Array(4.0, 2.0, 5.0)) + + val sm5 = dm3.toSparseColMajor + assert(sm5 === dm3) + assert(sm5.values === Array.empty[Double]) + assert(sm5.isColMajor) + + val sm2 = dm1.toSparseRowMajor + assert(sm2 === dm1) + assert(sm2.isRowMajor) + assert(sm2.values === Array(4.0, 5.0, 2.0)) + + val sm4 = dm2.toSparseRowMajor + assert(sm4 === dm2) + assert(sm4.isRowMajor) + assert(sm4.values === Array(4.0, 5.0, 2.0)) + + val sm6 = dm3.toSparseRowMajor + assert(sm6 === dm3) + assert(sm6.values === Array.empty[Double]) + assert(sm6.isRowMajor) + + val sm7 = dm1.toSparse + assert(sm7 === dm1) + assert(sm7.values === Array(4.0, 2.0, 5.0)) + assert(sm7.isColMajor) + + val sm10 = dm2.toSparse + assert(sm10 === dm2) + assert(sm10.values === Array(4.0, 5.0, 2.0)) + assert(sm10.isRowMajor) + } + + test("sparse to sparse") { + /* + sm1 = sm2 = sm3 = sm4 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + smZeros = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val sm1 = new SparseMatrix(2, 3, Array(0, 0, 2, 3), Array(0, 1, 0), Array(4.0, 2.0, 5.0)) + val sm2 = new SparseMatrix(2, 3, Array(0, 2, 3), Array(1, 2, 1), Array(4.0, 5.0, 2.0), + isTransposed = true) + val sm3 = new SparseMatrix(2, 3, Array(0, 0, 2, 4), Array(0, 1, 0, 1), + Array(4.0, 2.0, 5.0, 0.0)) + val sm4 = new SparseMatrix(2, 3, Array(0, 2, 4), Array(1, 2, 1, 2), + Array(4.0, 5.0, 2.0, 0.0), isTransposed = true) + val smZeros = new SparseMatrix(2, 3, Array(0, 2, 4, 6), Array(0, 1, 0, 1, 0, 1), + Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) + + val sm6 = sm1.toSparseColMajor + assert(sm6 === sm1) + assert(sm6.isColMajor) + assert(sm6.values.equals(sm1.values)) + + val sm7 = sm2.toSparseColMajor + assert(sm7 === sm2) + 
assert(sm7.isColMajor) + assert(sm7.values === Array(4.0, 2.0, 5.0)) + + val sm16 = sm3.toSparseColMajor + assert(sm16 === sm3) + assert(sm16.isColMajor) + assert(sm16.values === Array(4.0, 2.0, 5.0)) + + val sm14 = sm4.toSparseColMajor + assert(sm14 === sm4) + assert(sm14.values === Array(4.0, 2.0, 5.0)) + assert(sm14.isColMajor) + + val sm15 = smZeros.toSparseColMajor + assert(sm15 === smZeros) + assert(sm15.values === Array.empty[Double]) + assert(sm15.isColMajor) + + val sm5 = sm1.toSparseRowMajor + assert(sm5 === sm1) + assert(sm5.isRowMajor) + assert(sm5.values === Array(4.0, 5.0, 2.0)) + + val sm8 = sm2.toSparseRowMajor + assert(sm8 === sm2) + assert(sm8.isRowMajor) + assert(sm8.values.equals(sm2.values)) + + val sm10 = sm3.toSparseRowMajor + assert(sm10 === sm3) + assert(sm10.values === Array(4.0, 5.0, 2.0)) + assert(sm10.isRowMajor) + + val sm11 = sm4.toSparseRowMajor + assert(sm11 === sm4) + assert(sm11.values === Array(4.0, 5.0, 2.0)) + assert(sm11.isRowMajor) + + val sm17 = smZeros.toSparseRowMajor + assert(sm17 === smZeros) + assert(sm17.values === Array.empty[Double]) + assert(sm17.isRowMajor) + + val sm9 = sm3.toSparse + assert(sm9 === sm3) + assert(sm9.values === Array(4.0, 2.0, 5.0)) + assert(sm9.isColMajor) + + val sm12 = sm4.toSparse + assert(sm12 === sm4) + assert(sm12.values === Array(4.0, 5.0, 2.0)) + assert(sm12.isRowMajor) + + val sm13 = smZeros.toSparse + assert(sm13 === smZeros) + assert(sm13.values === Array.empty[Double]) + assert(sm13.isColMajor) + } + + test("sparse to dense") { + /* + sm1 = sm2 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + + sm3 = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val sm1 = new SparseMatrix(2, 3, Array(0, 0, 2, 3), Array(0, 1, 0), Array(4.0, 2.0, 5.0)) + val sm2 = new SparseMatrix(2, 3, Array(0, 2, 3), Array(1, 2, 1), Array(4.0, 5.0, 2.0), + isTransposed = true) + val sm3 = new SparseMatrix(2, 3, Array(0, 0, 0, 0), Array.empty[Int], Array.empty[Double]) + + val dm6 = sm1.toDenseColMajor + assert(dm6 === sm1) + assert(dm6.isColMajor) + assert(dm6.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm7 = sm2.toDenseColMajor + assert(dm7 === sm2) + assert(dm7.isColMajor) + assert(dm7.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm2 = sm1.toDenseRowMajor + assert(dm2 === sm1) + assert(dm2.isRowMajor) + assert(dm2.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm4 = sm2.toDenseRowMajor + assert(dm4 === sm2) + assert(dm4.isRowMajor) + assert(dm4.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm1 = sm1.toDense + assert(dm1 === sm1) + assert(dm1.isColMajor) + assert(dm1.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm3 = sm2.toDense + assert(dm3 === sm2) + assert(dm3.isRowMajor) + assert(dm3.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm5 = sm3.toDense + assert(dm5 === sm3) + assert(dm5.isColMajor) + assert(dm5.values === Array.fill(6)(0.0)) + } + + test("compressed dense") { + /* + dm1 = 1.0 0.0 0.0 0.0 + 1.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + + dm2 = 1.0 1.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + */ + // this should compress to a sparse matrix + val dm1 = new DenseMatrix(3, 4, Array.fill(2)(1.0) ++ Array.fill(10)(0.0)) + + // optimal compression layout is row major since numRows < numCols + val cm1 = dm1.compressed.asInstanceOf[SparseMatrix] + assert(cm1 === dm1) + assert(cm1.isRowMajor) + assert(cm1.getSizeInBytes < dm1.getSizeInBytes) + + // force compressed column major + val cm2 = dm1.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm2 === dm1) + assert(cm2.isColMajor) + assert(cm2.getSizeInBytes 
< dm1.getSizeInBytes) + + // optimal compression layout for transpose is column major + val dm2 = dm1.transpose + val cm3 = dm2.compressed.asInstanceOf[SparseMatrix] + assert(cm3 === dm2) + assert(cm3.isColMajor) + assert(cm3.getSizeInBytes < dm2.getSizeInBytes) + + /* + dm3 = 1.0 1.0 1.0 0.0 + 1.0 1.0 0.0 0.0 + 1.0 1.0 0.0 0.0 + + dm4 = 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 0.0 + 0.0 0.0 0.0 0.0 + */ + // this should compress to a dense matrix + val dm3 = new DenseMatrix(3, 4, Array.fill(7)(1.0) ++ Array.fill(5)(0.0)) + val dm4 = new DenseMatrix(3, 4, Array.fill(7)(1.0) ++ Array.fill(5)(0.0), isTransposed = true) + + val cm4 = dm3.compressed.asInstanceOf[DenseMatrix] + assert(cm4 === dm3) + assert(cm4.isColMajor) + assert(cm4.values.equals(dm3.values)) + assert(cm4.getSizeInBytes === dm3.getSizeInBytes) + + // force compressed row major + val cm5 = dm3.compressedRowMajor.asInstanceOf[DenseMatrix] + assert(cm5 === dm3) + assert(cm5.isRowMajor) + assert(cm5.getSizeInBytes === dm3.getSizeInBytes) + + val cm6 = dm4.compressed.asInstanceOf[DenseMatrix] + assert(cm6 === dm4) + assert(cm6.isRowMajor) + assert(cm6.values.equals(dm4.values)) + assert(cm6.getSizeInBytes === dm4.getSizeInBytes) + + val cm7 = dm4.compressedColMajor.asInstanceOf[DenseMatrix] + assert(cm7 === dm4) + assert(cm7.isColMajor) + assert(cm7.getSizeInBytes === dm4.getSizeInBytes) + + // this has the same size sparse or dense + val dm5 = new DenseMatrix(4, 4, Array.fill(7)(1.0) ++ Array.fill(9)(0.0)) + // should choose dense to break ties + val cm8 = dm5.compressed.asInstanceOf[DenseMatrix] + assert(cm8.getSizeInBytes === dm5.toSparseColMajor.getSizeInBytes) + } - assert(spMat1.asBreeze === spMat2.asBreeze) - assert(deMat1.asBreeze === deMat2.asBreeze) + test("compressed sparse") { + /* + sm1 = 0.0 -1.0 + 0.0 0.0 + 0.0 0.0 + 0.0 0.0 + + sm2 = 0.0 0.0 0.0 0.0 + -1.0 0.0 0.0 0.0 + */ + // these should compress to sparse matrices + val sm1 = new SparseMatrix(4, 2, Array(0, 0, 1), Array(0), Array(-1.0)) + val sm2 = sm1.transpose + + val cm1 = sm1.compressed.asInstanceOf[SparseMatrix] + // optimal is column major + assert(cm1 === sm1) + assert(cm1.isColMajor) + assert(cm1.values.equals(sm1.values)) + assert(cm1.getSizeInBytes === sm1.getSizeInBytes) + + val cm2 = sm1.compressedRowMajor.asInstanceOf[SparseMatrix] + assert(cm2 === sm1) + assert(cm2.isRowMajor) + // forced to be row major, so we have increased the size + assert(cm2.getSizeInBytes > sm1.getSizeInBytes) + assert(cm2.getSizeInBytes < sm1.toDense.getSizeInBytes) + + val cm9 = sm1.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm9 === sm1) + assert(cm9.values.equals(sm1.values)) + assert(cm9.getSizeInBytes === sm1.getSizeInBytes) + + val cm3 = sm2.compressed.asInstanceOf[SparseMatrix] + assert(cm3 === sm2) + assert(cm3.isRowMajor) + assert(cm3.values.equals(sm2.values)) + assert(cm3.getSizeInBytes === sm2.getSizeInBytes) + + val cm8 = sm2.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm8 === sm2) + assert(cm8.isColMajor) + // forced to be col major, so we have increased the size + assert(cm8.getSizeInBytes > sm2.getSizeInBytes) + assert(cm8.getSizeInBytes < sm2.toDense.getSizeInBytes) + + val cm10 = sm2.compressedRowMajor.asInstanceOf[SparseMatrix] + assert(cm10 === sm2) + assert(cm10.isRowMajor) + assert(cm10.values.equals(sm2.values)) + assert(cm10.getSizeInBytes === sm2.getSizeInBytes) + + + /* + sm3 = 0.0 -1.0 + 2.0 3.0 + -4.0 9.0 + */ + // this should compress to a dense matrix + val sm3 = new SparseMatrix(3, 2, Array(0, 2, 5), Array(1, 2, 0, 1, 2), + 
Array(2.0, -4.0, -1.0, 3.0, 9.0)) + + // dense is optimal, and maintains column major + val cm4 = sm3.compressed.asInstanceOf[DenseMatrix] + assert(cm4 === sm3) + assert(cm4.isColMajor) + assert(cm4.getSizeInBytes < sm3.getSizeInBytes) + + val cm5 = sm3.compressedRowMajor.asInstanceOf[DenseMatrix] + assert(cm5 === sm3) + assert(cm5.isRowMajor) + assert(cm5.getSizeInBytes < sm3.getSizeInBytes) + + val cm11 = sm3.compressedColMajor.asInstanceOf[DenseMatrix] + assert(cm11 === sm3) + assert(cm11.isColMajor) + assert(cm11.getSizeInBytes < sm3.getSizeInBytes) + + /* + sm4 = 1.0 0.0 0.0 ... + + sm5 = 1.0 + 0.0 + 0.0 + ... + */ + val sm4 = new SparseMatrix(Int.MaxValue, 1, Array(0, 1), Array(0), Array(1.0)) + val cm6 = sm4.compressed.asInstanceOf[SparseMatrix] + assert(cm6 === sm4) + assert(cm6.isColMajor) + assert(cm6.getSizeInBytes <= sm4.getSizeInBytes) + + val sm5 = new SparseMatrix(1, Int.MaxValue, Array(0, 1), Array(0), Array(1.0), + isTransposed = true) + val cm7 = sm5.compressed.asInstanceOf[SparseMatrix] + assert(cm7 === sm5) + assert(cm7.isRowMajor) + assert(cm7.getSizeInBytes <= sm5.getSizeInBytes) + + // this has the same size sparse or dense + val sm6 = new SparseMatrix(4, 4, Array(0, 4, 7, 7, 7), Array(0, 1, 2, 3, 0, 1, 2), + Array.fill(7)(1.0)) + // should choose dense to break ties + val cm12 = sm6.compressed.asInstanceOf[DenseMatrix] + assert(cm12.getSizeInBytes === sm6.getSizeInBytes) } test("map, update") { diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index ea22c2787fb3..dfbdaf19d374 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -336,6 +336,11 @@ class VectorsSuite extends SparkMLFunSuite { val sv1 = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) val sv1c = sv1.compressed.asInstanceOf[DenseVector] assert(sv1 === sv1c) + + val sv2 = Vectors.sparse(Int.MaxValue, Array(0), Array(3.4)) + val sv2c = sv2.compressed.asInstanceOf[SparseVector] + assert(sv2c === sv2) + assert(sv2c.numActives === 1) } test("SparseVector.slice") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8ce9367c9b44..2e3f9f2d0f3a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -81,7 +81,25 @@ object MimaExcludes { // [SPARK-19876] Add one time trigger, and improve Trigger APIs ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime") + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime"), + + // [SPARK-17471][ML] Add compressed method to ML matrices + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressed"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSparseSizeInBytes"), + 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDense"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparse"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getDenseSizeInBytes"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes") ) // Exclude rules for 2.1.x From 91fa80fe8a2480d64c430bd10f97b3d44c007bcc Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 24 Mar 2017 15:52:48 -0700 Subject: [PATCH 124/512] [SPARK-20070][SQL] Redact DataSourceScanExec treeString ## What changes were proposed in this pull request? The explain output of `DataSourceScanExec` can contain sensitive information (like Amazon keys). Such information should not end up in logs, or be exposed to non privileged users. This PR addresses this by adding a redaction facility for the `DataSourceScanExec.treeString`. A user can enable this by setting a regex in the `spark.redaction.string.regex` configuration. ## How was this patch tested? Added a unit test to check the output of DataSourceScanExec. Author: Herman van Hovell Closes #17397 from hvanhovell/SPARK-20070. 
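A minimal usage sketch of the new knob described above (editorial addition, not part of the patch; the application name, regex, and output path are invented for illustration):

```scala
import org.apache.spark.sql.SparkSession

// Opt in to string redaction: any fragment of a plan string matching the regex
// is replaced by the redaction placeholder when the plan is rendered.
val spark = SparkSession.builder()
  .appName("redaction-sketch")
  .master("local[*]")
  .config("spark.redaction.string.regex", "secret-[0-9a-f]+")
  .getOrCreate()

// Write and read back a directory whose name matches the regex; the matching
// part of the path then no longer shows up in the explain / treeString output.
val path = s"/tmp/secret-${java.util.UUID.randomUUID().toString.take(8)}"
spark.range(10).write.mode("overwrite").parquet(path)
val df = spark.read.parquet(path)
println(df.queryExecution.toString)
```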
--- .../spark/internal/config/ConfigBuilder.scala | 13 ++++ .../spark/internal/config/package.scala | 12 +++- .../scala/org/apache/spark/util/Utils.scala | 17 +++++- .../internal/config/ConfigEntrySuite.scala | 19 ++++-- .../sql/execution/DataSourceScanExec.scala | 41 ++++++++----- .../DataSourceScanExecRedactionSuite.scala | 60 +++++++++++++++++++ 6 files changed, 138 insertions(+), 24 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index a177e66645c7..d87619afd3b2 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -18,6 +18,9 @@ package org.apache.spark.internal.config import java.util.concurrent.TimeUnit +import java.util.regex.PatternSyntaxException + +import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} @@ -65,6 +68,13 @@ private object ConfigHelpers { def byteToString(v: Long, unit: ByteUnit): String = unit.convertTo(v, ByteUnit.BYTE) + "b" + def regexFromString(str: String, key: String): Regex = { + try str.r catch { + case e: PatternSyntaxException => + throw new IllegalArgumentException(s"$key should be a regex, but was $str", e) + } + } + } /** @@ -214,4 +224,7 @@ private[spark] case class ConfigBuilder(key: String) { new FallbackConfigEntry(key, _doc, _public, fallback) } + def regexConf: TypedConfigBuilder[Regex] = { + new TypedConfigBuilder(this, regexFromString(_, this.key), _.regex) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 223c92181037..89aeea493908 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -246,8 +246,16 @@ package object config { "driver and executor environments contain sensitive information. When this regex matches " + "a property, its value is redacted from the environment UI and various logs like YARN " + "and event logs.") - .stringConf - .createWithDefault("(?i)secret|password") + .regexConf + .createWithDefault("(?i)secret|password".r) + + private[spark] val STRING_REDACTION_PATTERN = + ConfigBuilder("spark.redaction.string.regex") + .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + + "information. When this regex matches a string part, that string part is replaced by a " + + "dummy value. This is currently used to redact the output of SQL explain commands.") + .regexConf + .createOptional private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1af34e3da231..943dde072327 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2585,13 +2585,26 @@ private[spark] object Utils extends Logging { } } - private[util] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" + private[spark] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" + /** + * Redact the sensitive values in the given map. If a map key matches the redaction pattern then + * its value is replaced with a dummy text. 
+ */ def redact(conf: SparkConf, kvs: Seq[(String, String)]): Seq[(String, String)] = { - val redactionPattern = conf.get(SECRET_REDACTION_PATTERN).r + val redactionPattern = conf.get(SECRET_REDACTION_PATTERN) redact(redactionPattern, kvs) } + /** + * Redact the sensitive information in the given string. + */ + def redact(conf: SparkConf, text: String): String = { + if (text == null || text.isEmpty || !conf.contains(STRING_REDACTION_PATTERN)) return text + val regex = conf.get(STRING_REDACTION_PATTERN).get + regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + } + private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { kvs.map { kv => redactionPattern.findFirstIn(kv._1) diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 71eed464880b..f3756b21080b 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.internal.config import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ -import scala.collection.mutable.HashMap - import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.network.util.ByteUnit import org.apache.spark.util.SparkConfWithEnv @@ -98,6 +95,21 @@ class ConfigEntrySuite extends SparkFunSuite { assert(conf.get(bytes) === 1L) } + test("conf entry: regex") { + val conf = new SparkConf() + val rConf = ConfigBuilder(testKey("regex")).regexConf.createWithDefault(".*".r) + + conf.set(rConf, "[0-9a-f]{8}".r) + assert(conf.get(rConf).regex === "[0-9a-f]{8}") + + conf.set(rConf.key, "[0-9a-f]{4}") + assert(conf.get(rConf).regex === "[0-9a-f]{4}") + + conf.set(rConf.key, "[.") + val e = intercept[IllegalArgumentException](conf.get(rConf)) + assert(e.getMessage.contains("regex should be a regex, but was")) + } + test("conf entry: string seq") { val conf = new SparkConf() val seq = ConfigBuilder(testKey("seq")).stringConf.toSequence.createWithDefault(Seq()) @@ -239,5 +251,4 @@ class ConfigEntrySuite extends SparkFunSuite { .createWithDefault(null) testEntryRef(nullConf, ref(nullConf)) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index bfe9c8e351ab..28156b277f59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -41,9 +41,33 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { val relation: BaseRelation val metastoreTableIdentifier: Option[TableIdentifier] + protected val nodeNamePrefix: String = "" + override val nodeName: String = { s"Scan $relation ${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}" } + + override def simpleString: String = { + val metadataEntries = metadata.toSeq.sorted.map { + case (key, value) => + key + ": " + StringUtils.abbreviate(redact(value), 100) + } + val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") + s"$nodeNamePrefix$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" + } + + override def verboseString: String = redact(super.verboseString) + + override def treeString(verbose: Boolean, addSuffix: Boolean): String = { + redact(super.treeString(verbose, addSuffix)) + } + + 
/** + * Shorthand for calling redactString() without specifying redacting rules + */ + private def redact(text: String): String = { + Utils.redact(SparkSession.getActiveSession.get.sparkContext.conf, text) + } } /** Physical plan node for scanning data from a relation. */ @@ -85,15 +109,6 @@ case class RowDataSourceScanExec( } } - override def simpleString: String = { - val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield { - key + ": " + StringUtils.abbreviate(value, 100) - } - - s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" + - s"${Utils.truncatedString(metadataEntries, " ", ", ", "")}" - } - override def inputRDDs(): Seq[RDD[InternalRow]] = { rdd :: Nil } @@ -307,13 +322,7 @@ case class FileSourceScanExec( } } - override def simpleString: String = { - val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield { - key + ": " + StringUtils.abbreviate(value, 100) - } - val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") - s"File$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" - } + override val nodeNamePrefix: String = "File" override protected def doProduce(ctx: CodegenContext): String = { if (supportsBatch) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala new file mode 100644 index 000000000000..986fa878ee29 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +/** + * Suite that tests the redaction of DataSourceScanExec + */ +class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { + + import Utils._ + + override def beforeAll(): Unit = { + sparkConf.set("spark.redaction.string.regex", + "spark-[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}") + super.beforeAll() + } + + test("treeString is redacted") { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + val df = spark.read.parquet(basePath) + + val rootPath = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + .asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head + assert(rootPath.toString.contains(basePath.toString)) + + assert(!df.queryExecution.sparkPlan.treeString(verbose = true).contains(rootPath.getName)) + assert(!df.queryExecution.executedPlan.treeString(verbose = true).contains(rootPath.getName)) + assert(!df.queryExecution.toString.contains(rootPath.getName)) + assert(!df.queryExecution.simpleString.contains(rootPath.getName)) + + val replacement = "*********" + assert(df.queryExecution.sparkPlan.treeString(verbose = true).contains(replacement)) + assert(df.queryExecution.executedPlan.treeString(verbose = true).contains(replacement)) + assert(df.queryExecution.toString.contains(replacement)) + assert(df.queryExecution.simpleString.contains(replacement)) + } + } +} From b5c5bd98ea5e8dbfebcf86c5459bdf765f5ceb53 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 Mar 2017 23:57:29 +0100 Subject: [PATCH 125/512] Disable generate codegen since it fails my workload. 
--- .../spark/sql/execution/GenerateExec.scala | 2 +- .../execution/WholeStageCodegenSuite.scala | 28 ------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 69be7094d2c3..f87d05884b27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -119,7 +119,7 @@ case class GenerateExec( } } - override def supportCodegen: Boolean = generator.supportCodegen + override def supportCodegen: Boolean = false override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4d9203556d49..a4b30a2f8cec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -116,34 +116,6 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } - test("generate should be included in WholeStageCodegen") { - import org.apache.spark.sql.functions._ - val ds = spark.range(2).select( - col("id"), - explode(array(col("id") + 1, col("id") + 2)).as("value")) - val plan = ds.queryExecution.executedPlan - assert(plan.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[GenerateExec]).isDefined) - assert(ds.collect() === Array(Row(0, 1), Row(0, 2), Row(1, 2), Row(1, 3))) - } - - test("large stack generator should not use WholeStageCodegen") { - def createStackGenerator(rows: Int): SparkPlan = { - val id = UnresolvedAttribute("id") - val stack = Stack(Literal(rows) +: Seq.tabulate(rows)(i => Add(id, Literal(i)))) - spark.range(500).select(Column(stack)).queryExecution.executedPlan - } - val isCodeGenerated: SparkPlan => Boolean = { - case WholeStageCodegenExec(_: GenerateExec) => true - case _ => false - } - - // Only 'stack' generators that produce 50 rows or less are code generated. - assert(createStackGenerator(50).find(isCodeGenerated).isDefined) - assert(createStackGenerator(100).find(isCodeGenerated).isEmpty) - } - test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) From e011004bedca47be998a0c14fe22a6f9bb5090cd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 25 Mar 2017 00:04:51 +0100 Subject: [PATCH 126/512] [SPARK-19846][SQL] Add a flag to disable constraint propagation ## What changes were proposed in this pull request? Constraint propagation can be computation expensive and block the driver execution for long time. For example, the below benchmark needs 30mins. Compared with previous PRs #16998, #16785, this is a much simpler option: add a flag to disable constraint propagation. ### Benchmark Run the following codes locally. 
import org.apache.spark.ml.{Pipeline, PipelineStage} import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler} import org.apache.spark.sql.internal.SQLConf spark.conf.set(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key, false) val df = (1 to 40).foldLeft(Seq((1, "foo"), (2, "bar"), (3, "baz")).toDF("id", "x0"))((df, i) => df.withColumn(s"x$i", $"x0")) val indexers = df.columns.tail.map(c => new StringIndexer() .setInputCol(c) .setOutputCol(s"${c}_indexed") .setHandleInvalid("skip")) val encoders = indexers.map(indexer => new OneHotEncoder() .setInputCol(indexer.getOutputCol) .setOutputCol(s"${indexer.getOutputCol}_encoded") .setDropLast(true)) val stages: Array[PipelineStage] = indexers ++ encoders val pipeline = new Pipeline().setStages(stages) val startTime = System.nanoTime pipeline.fit(df).transform(df).show val runningTime = System.nanoTime - startTime Before this patch: 1786001 ms ~= 30 mins After this patch: 26392 ms = less than half of a minute Related PRs: #16998, #16785. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17186 from viirya/add-flag-disable-constraint-propagation. --- .../sql/catalyst/SimpleCatalystConf.scala | 3 +- .../sql/catalyst/optimizer/Optimizer.scala | 22 ++++++---- .../spark/sql/catalyst/optimizer/joins.scala | 6 ++- .../spark/sql/catalyst/plans/QueryPlan.scala | 11 +++++ .../apache/spark/sql/internal/SQLConf.scala | 11 +++++ .../BinaryComparisonSimplificationSuite.scala | 5 ++- .../BooleanSimplificationSuite.scala | 5 ++- .../InferFiltersFromConstraintsSuite.scala | 19 ++++++++- .../optimizer/OuterJoinEliminationSuite.scala | 30 +++++++++++++- .../PropagateEmptyRelationSuite.scala | 5 ++- .../optimizer/PruneFiltersSuite.scala | 40 ++++++++++++++++++- .../optimizer/SetOperationSuite.scala | 3 +- .../plans/ConstraintPropagationSuite.scala | 18 +++++++++ 13 files changed, 158 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala index ac97987c55e0..8498cf1c9be7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala @@ -43,7 +43,8 @@ case class SimpleCatalystConf( override val starSchemaDetection: Boolean = false, override val warehousePath: String = "/user/hive/warehouse", override val sessionLocalTimeZone: String = TimeZone.getDefault().getID, - override val maxNestedViewDepth: Int = 100) + override val maxNestedViewDepth: Int = 100, + override val constraintPropagationEnabled: Boolean = true) extends SQLConf { override def clone(): SimpleCatalystConf = this.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d7524a57adbc..ee7de8692149 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -83,12 +83,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // Operator push down PushProjectionThroughUnion, ReorderJoin(conf), - EliminateOuterJoin, + EliminateOuterJoin(conf), PushPredicateThroughJoin, PushDownPredicate, LimitPushDown(conf), 
ColumnPruning, - InferFiltersFromConstraints, + InferFiltersFromConstraints(conf), // Operator combine CollapseRepartition, CollapseProject, @@ -107,7 +107,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - PruneFilters, + PruneFilters(conf), EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, @@ -615,8 +615,16 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { +case class InferFiltersFromConstraints(conf: CatalystConf) + extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) { + inferFilters(plan) + } else { + plan + } + + + private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, child) => val newFilters = filter.constraints -- (child.constraints ++ splitConjunctivePredicates(condition)) @@ -705,7 +713,7 @@ object EliminateSorts extends Rule[LogicalPlan] { * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`. * 3) by eliminating the always-true conditions given the constraints on the child's output. */ -object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { +case class PruneFilters(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child @@ -718,7 +726,7 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { case f @ Filter(fc, p: LogicalPlan) => val (prunedPredicates, remainingPredicates) = splitConjunctivePredicates(fc).partition { cond => - cond.deterministic && p.constraints.contains(cond) + cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond) } if (prunedPredicates.isEmpty) { f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 58e4a230f4ef..5f7316566b3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation} import org.apache.spark.sql.catalyst.plans._ @@ -439,7 +440,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe * * This rule should be executed before pushing down the Filter */ -object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { +case class EliminateOuterJoin(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Returns whether the expression returns null or false when all inputs are nulls. 
@@ -455,7 +456,8 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } private def buildNewJoinType(filter: Filter, join: Join): JoinType = { - val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints + val conditions = splitConjunctivePredicates(filter.condition) ++ + filter.getConstraints(conf.constraintPropagationEnabled) val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index a5761703fd65..9fd95a4b368c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -186,6 +186,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) + /** + * Returns [[constraints]] depending on the config of enabling constraint propagation. If the + * flag is disabled, simply returning an empty constraints. + */ + private[spark] def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet = + if (constraintPropagationEnabled) { + constraints + } else { + ExpressionSet(Set.empty) + } + /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d5006c16469b..5566b06aa355 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -187,6 +187,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") + .internal() + .doc("When true, the query optimizer will infer and propagate data constraints in the query " + + "plan to optimize them. Constraint propagation can sometimes be computationally expensive" + + "for certain kinds of query plans (such as those with a large number of predicates and " + + "aliases) which might negatively impact overall runtime.") + .booleanConf + .createWithDefault(true) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -887,6 +896,8 @@ class SQLConf extends Serializable with Logging { def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) + def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index a0d489681fd9..2bfddb7bc2f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -30,15 +30,16 @@ import org.apache.spark.sql.catalyst.rules._ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, - PruneFilters) :: Nil + PruneFilters(conf)) :: Nil } val nullableRelation = LocalRelation('a.int.withNullability(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 1b9db0601492..4d404f55aa57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -30,14 +30,15 @@ import org.apache.spark.sql.catalyst.rules._ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, - PruneFilters) :: Nil + PruneFilters(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 9f57f66a2ea2..98d8b897a916 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -31,7 +32,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints, + InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true)), + CombineFilters) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("InferAndPushDownFilters", FixedPoint(100), + PushPredicateThroughJoin, + PushDownPredicate, + 
InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true, + constraintPropagationEnabled = false)), CombineFilters) :: Nil } @@ -201,4 +212,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("No inferred filter when constraint propagation is disabled") { + val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze + val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery) + comparePlans(optimized, originalQuery) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index c168a55e40c5..cbabc1fa6d92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -31,7 +32,17 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin, + EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true)), + PushPredicateThroughJoin) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Outer Join Elimination", Once, + EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true, + constraintPropagationEnabled = false)), PushPredicateThroughJoin) :: Nil } @@ -231,4 +242,21 @@ class OuterJoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("no outer join elimination if constraint propagation is disabled") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + // The predicate "x.b + y.d >= 3" will be inferred constraints like: + // "x.b != null" and "y.d != null", if constraint propagation is enabled. + // When we disable it, the predicate can't be evaluated on left or right plan and used to + // filter out nulls. So the Outer Join will not be eliminated. 
+ val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where("x.b".attr + "y.d".attr >= 3) + + val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze) + + comparePlans(optimized, originalQuery.analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 908dde7a6698..f771e3e9eba6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans._ @@ -33,7 +34,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters, + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)), PropagateEmptyRelation) :: Nil } @@ -45,7 +46,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index d8cfec539149..20f7f69e86c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -33,7 +34,19 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters, + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)), + PushDownPredicate, + PushPredicateThroughJoin) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Filter Pushdown and Pruning", Once, + CombineFilters, + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true, + constraintPropagationEnabled = false)), PushDownPredicate, PushPredicateThroughJoin) :: Nil } @@ -133,4 +146,29 @@ class PruneFiltersSuite extends PlanTest { val correctAnswer = testRelation.where(Rand(10) > 5).where(Rand(10) > 5).select('a).analyze comparePlans(optimized, correctAnswer) } + + test("No pruning when constraint propagation is disabled") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + + val query = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100), Inner, 
Some("tr1.a".attr === "tr2.a".attr)) + + val queryWithUselessFilter = + query.where( + ("tr1.a".attr > 10 || "tr1.c".attr < 10) && + 'd.attr < 100) + + val optimized = + OptimizeWithConstraintPropagationDisabled.execute(queryWithUselessFilter.analyze) + // When constraint propagation is disabled, the useless filter won't be pruned. + // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant + // and duplicate filters. + val correctAnswer = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100).where('d.attr < 100), + Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 21b7f49e14bd..ca4976f0d6db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -34,7 +35,7 @@ class SetOperationSuite extends PlanTest { CombineUnions, PushProjectionThroughUnion, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 908b37040828..4061394b862a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -397,4 +397,22 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "c"))))) } + + test("enable/disable constraint propagation") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + val filterRelation = tr.where('a.attr > 10) + + verifyConstraints( + filterRelation.analyze.getConstraints(constraintPropagationEnabled = true), + filterRelation.analyze.constraints) + + assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) + + verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true), + aliasedRelation.analyze.constraints) + assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + } } From f88f56b835b3a61ff2d59236e7fa05eda5aefcaa Mon Sep 17 00:00:00 2001 From: Roxanne Moslehi Date: Sat, 25 Mar 2017 00:10:30 +0100 Subject: [PATCH 127/512] [DOCS] Clarify round mode for format_number & round functions ## What changes were proposed in this pull request? Updated the description for the `format_number` description to indicate that it uses `HALF_EVEN` rounding. 
Updated the description for the `round` description to indicate that it uses `HALF_UP` rounding. ## How was this patch tested? Just changing the two function comments so no testing involved. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Roxanne Moslehi Author: roxannemoslehi Closes #17399 from roxannemoslehi/patch-1. --- .../main/scala/org/apache/spark/sql/functions.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 66bb8816a670..acdb8e2d3edc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1861,7 +1861,7 @@ object functions { def rint(columnName: String): Column = rint(Column(columnName)) /** - * Returns the value of the column `e` rounded to 0 decimal places. + * Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode. * * @group math_funcs * @since 1.5.0 @@ -1869,8 +1869,8 @@ object functions { def round(e: Column): Column = round(e, 0) /** - * Round the value of `e` to `scale` decimal places if `scale` is greater than or equal to 0 - * or at integral part when `scale` is less than 0. + * Round the value of `e` to `scale` decimal places with HALF_UP round mode + * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. * * @group math_funcs * @since 1.5.0 @@ -2191,8 +2191,8 @@ object functions { } /** - * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, - * and returns the result as a string column. + * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places + * with HALF_EVEN round mode, and returns the result as a string column. * * If d is 0, the result has no decimal point or fractional part. * If d is less than 0, the result will be null. From 0a6c50711b871dce1a04f5dc7652a0b936369fa0 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 25 Mar 2017 01:07:50 +0100 Subject: [PATCH 128/512] [SPARK-20070][SQL] Fix 2.10 build ## What changes were proposed in this pull request? Commit https://github.com/apache/spark/commit/91fa80fe8a2480d64c430bd10f97b3d44c007bcc broke the build for scala 2.10. The commit uses `Regex.regex` field which is not available in Scala 2.10. This PR fixes this. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17420 from hvanhovell/SPARK-20070-2.0. 
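As a quick, standalone illustration of why the swap below is safe (not part of the patch itself): `Regex.toString` returns the original pattern string, which is exactly what the 2.11-only `regex` field exposed.

```scala
import scala.util.matching.Regex

// Sanity check: toString on a compiled Regex gives back the pattern string,
// so it can stand in for the `regex` field that Scala 2.10 does not provide.
val r: Regex = "[0-9a-f]{8}".r
assert(r.toString == "[0-9a-f]{8}")
```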
--- .../org/apache/spark/internal/config/ConfigBuilder.scala | 2 +- .../org/apache/spark/internal/config/ConfigEntrySuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index d87619afd3b2..b9921138cc6c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -225,6 +225,6 @@ private[spark] case class ConfigBuilder(key: String) { } def regexConf: TypedConfigBuilder[Regex] = { - new TypedConfigBuilder(this, regexFromString(_, this.key), _.regex) + new TypedConfigBuilder(this, regexFromString(_, this.key), _.toString) } } diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index f3756b21080b..3ff7e84d73bd 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -100,10 +100,10 @@ class ConfigEntrySuite extends SparkFunSuite { val rConf = ConfigBuilder(testKey("regex")).regexConf.createWithDefault(".*".r) conf.set(rConf, "[0-9a-f]{8}".r) - assert(conf.get(rConf).regex === "[0-9a-f]{8}") + assert(conf.get(rConf).toString === "[0-9a-f]{8}") conf.set(rConf.key, "[0-9a-f]{4}") - assert(conf.get(rConf).regex === "[0-9a-f]{4}") + assert(conf.get(rConf).toString === "[0-9a-f]{4}") conf.set(rConf.key, "[.") val e = intercept[IllegalArgumentException](conf.get(rConf)) From a2ce0a2e309e70d74ae5d2ed203f7919a0f79397 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 24 Mar 2017 23:27:42 -0700 Subject: [PATCH 129/512] [HOTFIX][SQL] Fix the failed test cases in GeneratorFunctionSuite ### What changes were proposed in this pull request? Multiple tests failed. Revert the changes on `supportCodegen` of `GenerateExec`. For example, - https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/75194/testReport/ ### How was this patch tested? N/A Author: Xiao Li Closes #17425 from gatorsmile/turnOnCodeGenGenerateExec. 
--- .../apache/spark/sql/GeneratorFunctionSuite.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index b9871afd59e4..cef5bbf0e85a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -91,7 +91,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(explode_outer('intList)), - Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Nil) } test("single posexplode") { @@ -105,7 +105,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(posexplode_outer('intList)), - Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil) + Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil) } test("explode and other columns") { @@ -161,7 +161,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('intList).as('int)).select('int), - Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Nil) checkAnswer( df.select(explode('intList).as('int)).select(sum('int)), @@ -182,7 +182,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map)), - Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil) + Row("a", "b") :: Row("c", "d") :: Nil) } test("explode on map with aliases") { @@ -198,7 +198,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), - Row("a", "b") :: Row(null, null) :: Nil) + Row("a", "b") :: Nil) } test("self join explode") { @@ -279,7 +279,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { ) checkAnswer( df2.selectExpr("inline_outer(col1)"), - Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil + Row(3, "4") :: Row(5, "6") :: Nil ) } From e8ddb91c7ea5a0b4576cf47aaf969bcc82860b7c Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Sat, 25 Mar 2017 10:42:15 +0000 Subject: [PATCH 130/512] [SPARK-20078][MESOS] Mesos executor configurability for task name and labels ## What changes were proposed in this pull request? Adding configurable mesos executor names and labels using `spark.mesos.task.name` and `spark.mesos.task.labels`. Labels were defined as `k1:v1,k2:v2`. mgummelt ## How was this patch tested? Added unit tests to verify labels were added correctly, with incorrect labels being ignored and added a test to test the name of the executor. Tested with: `./build/sbt -Pmesos mesos/test` Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Kalvin Chau Closes #17404 from kalvinnchau/mesos-config. 
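For reference, a hedged sketch of how an application would opt into these settings. The `spark.mesos.task.labels` key and its `k1:v1,k2:v2` format are taken from the PR description above; the hunks below only show the task name being derived from the application name, so the task name itself is controlled through `spark.app.name`.

```scala
import org.apache.spark.SparkConf

// Task names become "<appName> <taskId>" (e.g. "test-mesos-dynamic-alloc 0" in the test below).
// The labels key and format follow the PR description; treat the exact key as described, not verified here.
val conf = new SparkConf()
  .setAppName("test-mesos-dynamic-alloc")
  .set("spark.mesos.task.labels", "k1:v1,k2:v2")
```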
--- .../mesos/MesosCoarseGrainedSchedulerBackend.scala | 3 ++- .../MesosCoarseGrainedSchedulerBackendSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index c049a32eabf9..5bdc2a2b840e 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -403,7 +403,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) - .setName("Task " + taskId) + .setName(s"${sc.appName} $taskId") + taskBuilder.addAllResources(resourcesToUse.asJava) taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 98033bec6dd6..eb83926ae410 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -464,6 +464,17 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(!uris.asScala.head.getCache) } + test("mesos sets task name to spark.app.name") { + setBackend() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + // Add " 0" to the taskName to match the executor number that is appended + assert(launchedTasks.head.getName == "test-mesos-dynamic-alloc 0") + } + test("mesos supports spark.mesos.network.name") { setBackend(Map( "spark.mesos.network.name" -> "test-network-name" From be85245a98d58f636ff54956cdfde15ea5cd6122 Mon Sep 17 00:00:00 2001 From: sethah Date: Sat, 25 Mar 2017 17:41:59 +0000 Subject: [PATCH 131/512] [SPARK-17137][ML][WIP] Compress logistic regression coefficients ## What changes were proposed in this pull request? Use the new `compressed` method on matrices to store the logistic regression coefficients as sparse or dense - whichever is requires less memory. Marked as WIP so we can add some performance test results. Basically, we should see if prediction is slower because of using a sparse matrix over a dense one. This can happen since sparse matrices do not use native BLAS operations when computing the margins. ## How was this patch tested? Unit tests added. Author: sethah Closes #17426 from sethah/SPARK-17137. 
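A small standalone sketch of the `compressed` behavior the patch relies on, assuming the default `ml.linalg` size heuristic: the call returns whichever of the dense or sparse layouts needs less memory.

```scala
import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix}

// A 5x5 matrix with a single non-zero: the sparse layout is much smaller than the
// 25-element dense array, so compressed is expected to return a SparseMatrix here.
val values = Array.fill(25)(0.0)
values(12) = 5.0
val mostlyZeros = Matrices.dense(5, 5, values)
assert(mostlyZeros.compressed.isInstanceOf[SparseMatrix])

// A fully populated matrix stays dense, since the sparse layout would be larger.
val allOnes = Matrices.dense(5, 5, Array.fill(25)(1.0))
assert(allOnes.compressed.isInstanceOf[DenseMatrix])
```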
--- .../classification/LogisticRegression.scala | 28 ++------- .../LogisticRegressionSuite.scala | 58 ++++++++++++++----- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 1a78187d4f8e..7b56bce41c32 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -399,14 +399,9 @@ class LogisticRegression @Since("1.2.0") ( logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + s"will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax - // TODO: use `compressed` after SPARK-17471 - val coefMatrix = if (numFeatures < numCoefficientSets) { - new SparseMatrix(numCoefficientSets, numFeatures, - Array.fill(numFeatures + 1)(0), Array.empty[Int], Array.empty[Double]) - } else { - new SparseMatrix(numCoefficientSets, numFeatures, Array.fill(numCoefficientSets + 1)(0), - Array.empty[Int], Array.empty[Double], isTransposed = true) - } + val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures, + new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double], + isTransposed = true).compressed val interceptVec = if (isMultinomial) { Vectors.sparse(numClasses, Seq((constantLabelIndex, Double.PositiveInfinity))) } else { @@ -617,26 +612,13 @@ class LogisticRegression @Since("1.2.0") ( denseCoefficientMatrix.update(_ - coefficientMean) } - // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471 - val compressedCoefficientMatrix = if (isMultinomial) { - denseCoefficientMatrix - } else { - val compressedVector = Vectors.dense(denseCoefficientMatrix.values).compressed - compressedVector match { - case dv: DenseVector => denseCoefficientMatrix - case sv: SparseVector => - new SparseMatrix(1, numFeatures, Array(0, sv.indices.length), sv.indices, sv.values, - isTransposed = true) - } - } - // center the intercepts when using multinomial algorithm if ($(fitIntercept) && isMultinomial) { val interceptArray = interceptVec.toArray val interceptMean = interceptArray.sum / interceptArray.length (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } } - (compressedCoefficientMatrix, interceptVec.compressed, arrayBuilder.result()) + (denseCoefficientMatrix.compressed, interceptVec.compressed, arrayBuilder.result()) } } @@ -713,7 +695,7 @@ class LogisticRegressionModel private[spark] ( // convert to appropriate vector representation without replicating data private lazy val _coefficients: Vector = { require(coefficientMatrix.isTransposed, - "LogisticRegressionModel coefficients should be row major.") + "LogisticRegressionModel coefficients should be row major for binomial model.") coefficientMatrix match { case dm: DenseMatrix => Vectors.dense(dm.values) case sm: SparseMatrix => Vectors.sparse(coefficientMatrix.numCols, sm.rowIndices, sm.values) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index affaa573749e..1b6448037349 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -26,7 +26,7 @@ import 
org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -713,8 +713,6 @@ class LogisticRegressionSuite assert(model2.intercept ~== interceptR relTol 1E-2) assert(model2.coefficients ~== coefficientsR absTol 1E-3) - // TODO: move this to a standalone test of compression after SPARK-17471 - assert(model2.coefficients.isInstanceOf[SparseVector]) } test("binary logistic regression without intercept with L1 regularization") { @@ -2031,29 +2029,61 @@ class LogisticRegressionSuite // TODO: check num iters is zero when it become available in the model } - test("compressed storage") { + test("compressed storage for constant label") { + /* + When the label is constant and fit intercept is true, all the coefficients will be + zeros, and so the model coefficients should be stored as sparse data structures, except + when the matrix dimensions are very small. + */ val moreClassesThanFeatures = Seq( - LabeledPoint(4.0, Vectors.dense(0.0, 0.0, 0.0)), - LabeledPoint(4.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(4.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() - val mlr = new LogisticRegression().setFamily("multinomial") + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(0.0))), + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(1.0))), + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(2.0)))).toDF() + val mlr = new LogisticRegression().setFamily("multinomial").setFitIntercept(true) val model = mlr.fit(moreClassesThanFeatures) assert(model.coefficientMatrix.isInstanceOf[SparseMatrix]) - assert(model.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 4) + assert(model.coefficientMatrix.isColMajor) + + // in this case, it should be stored as row major val moreFeaturesThanClasses = Seq( - LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(0.0))), + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(1.0))), + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(2.0)))).toDF() val model2 = mlr.fit(moreFeaturesThanClasses) assert(model2.coefficientMatrix.isInstanceOf[SparseMatrix]) - assert(model2.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 3) + assert(model2.coefficientMatrix.isRowMajor) - val blr = new LogisticRegression().setFamily("binomial") + val blr = new LogisticRegression().setFamily("binomial").setFitIntercept(true) val blrModel = blr.fit(moreFeaturesThanClasses) assert(blrModel.coefficientMatrix.isInstanceOf[SparseMatrix]) assert(blrModel.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 2) } + test("compressed coefficients") { + + val trainer1 = new LogisticRegression() + .setRegParam(0.1) + .setElasticNetParam(1.0) + + // compressed row major is optimal + val model1 = trainer1.fit(multinomialDataset.limit(100)) + assert(model1.coefficientMatrix.isInstanceOf[SparseMatrix]) + assert(model1.coefficientMatrix.isRowMajor) + + // compressed column major is optimal since there are more classes 
than features + val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata() + val model2 = trainer1.fit(multinomialDataset + .withColumn("label", col("label").as("label", labelMeta)).limit(100)) + assert(model2.coefficientMatrix.isInstanceOf[SparseMatrix]) + assert(model2.coefficientMatrix.isColMajor) + + // coefficients are dense without L1 regularization + val trainer2 = new LogisticRegression() + .setElasticNetParam(0.0) + val model3 = trainer2.fit(multinomialDataset.limit(100)) + assert(model3.coefficientMatrix.isInstanceOf[DenseMatrix]) + } + test("numClasses specified in metadata/inferred") { val lr = new LogisticRegression().setMaxIter(1).setFamily("multinomial") From 0b903caef3183c5113feb09995874f6a07aa6698 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 25 Mar 2017 11:46:54 -0700 Subject: [PATCH 132/512] [SPARK-19949][SQL][FOLLOW-UP] move FailureSafeParser from catalyst to sql core ## What changes were proposed in this pull request? The `FailureSafeParser` is only used in sql core, it doesn't make sense to put it in catalyst module. ## How was this patch tested? N/A Author: Wenchen Fan Closes #17408 from cloud-fan/minor. --- .../catalyst/util/BadRecordException.scala | 33 +++++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 3 +- .../datasources}/FailureSafeParser.scala | 15 ++------- .../datasources/csv/UnivocityParser.scala | 3 +- .../datasources/json/JsonDataSource.scala | 3 +- 5 files changed, 39 insertions(+), 18 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala rename sql/{catalyst/src/main/scala/org/apache/spark/sql/catalyst/util => core/src/main/scala/org/apache/spark/sql/execution/datasources}/FailureSafeParser.scala (82%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala new file mode 100644 index 000000000000..985f0dc1cd60 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.UTF8String + +/** + * Exception thrown when the underlying parser meet a bad record and can't parse it. + * @param record a function to return the record that cause the parser to fail + * @param partialResult a function that returns an optional row, which is the partial result of + * parsing this bad record. + * @param cause the actual exception about why the record is bad and can't be parsed. 
+ */ +case class BadRecordException( + record: () => UTF8String, + partialResult: () => Option[InternalRow], + cause: Throwable) extends Exception(cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e6d2b1bc28d9..6c238618f2af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -27,11 +27,10 @@ import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} import org.apache.spark.sql.execution.datasources.csv._ -import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.types.{StringType, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala similarity index 82% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala index 725e3015b341..159aef220be1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala @@ -15,10 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.util +package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String @@ -69,15 +70,3 @@ class FailureSafeParser[IN]( } } } - -/** - * Exception thrown when the underlying parser meet a bad record and can't parse it. - * @param record a function to return the record that cause the parser to fail - * @param partialResult a function that returns an optional row, which is the partial result of - * parsing this bad record. - * @param cause the actual exception about why the record is bad and can't be parsed. 
- */ -case class BadRecordException( - record: () => UTF8String, - partialResult: () => Option[InternalRow], - cause: Throwable) extends Exception(cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 263f77e11c4d..c3657acb7d86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -30,7 +30,8 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils} +import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 51e952c12202..4f2963da9ace 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -33,8 +33,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.FailureSafeParser -import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String From 2422c86f2ce2dd649b1d63062ec5c5fc1716c519 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 25 Mar 2017 23:29:02 -0700 Subject: [PATCH 133/512] [SPARK-20092][R][PROJECT INFRA] Add the detection for Scala codes dedicated for R in AppVeyor tests ## What changes were proposed in this pull request? We are currently detecting the changes in `R/` directory only and then trigger AppVeyor tests. It seems we need to tests when there are Scala codes dedicated for R in `core/src/main/scala/org/apache/spark/api/r/`, `sql/core/src/main/scala/org/apache/spark/sql/api/r/` and `mllib/src/main/scala/org/apache/spark/ml/r/` too. This will enables the tests, for example, for SPARK-20088. ## How was this patch tested? Tests with manually created PRs. - Changes in `sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala` https://github.com/spark-test/spark/pull/13 - Changes in `core/src/main/scala/org/apache/spark/api/r/SerDe.scala` https://github.com/spark-test/spark/pull/12 - Changes in `README.md` https://github.com/spark-test/spark/pull/14 Author: hyukjinkwon Closes #17427 from HyukjinKwon/SPARK-20092. 
--- appveyor.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/appveyor.yml b/appveyor.yml index 5adf1b4bedb4..bbb27589cad0 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -27,6 +27,9 @@ branches: only_commits: files: - R/ + - sql/core/src/main/scala/org/apache/spark/sql/api/r/ + - core/src/main/scala/org/apache/spark/api/r/ + - mllib/src/main/scala/org/apache/spark/ml/r/ cache: - C:\Users\appveyor\.m2 From 93bb0b911b6c790fa369b39da51a83d8f62da909 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 26 Mar 2017 09:20:22 +0200 Subject: [PATCH 134/512] [SPARK-20046][SQL] Facilitate loop optimizations in a JIT compiler regarding sqlContext.read.parquet() ## What changes were proposed in this pull request? This PR improves performance of operations with `sqlContext.read.parquet()` by changing Java code generated by Catalyst. This PR is inspired by [the blog article](https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html) and [this stackoverflow entry](http://stackoverflow.com/questions/40629435/fast-parquet-row-count-in-spark). This PR changes generated code in the following two points. 1. Replace a while-loop with long instance variables a for-loop with int local variables 2. Suppress generation of `shouldStop()` method if this method is unnecessary (e.g. `append()` is not generated). These points facilitates compiler optimizations in a JIT compiler by feeding the simplified Java code into the JIT compiler. The performance of `sqlContext.read.parquet().count` is improved by 1.09x. Benchmark program: ```java val dir = "/dev/shm/parquet" val N = 1000 * 1000 * 40 val iters = 20 val benchmark = new Benchmark("Parquet", N * iters, minNumIters = 5, warmupTime = 30.seconds) sparkSession.range(n).write.mode("overwrite").parquet(dir) benchmark.addCase("count") { i: Int => var n = 0 var len = 0L while (n < iters) { len += sparkSession.read.parquet(dir).count n += 1 } } benchmark.run ``` Performance result without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Parquet: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ w/o this PR 1152 / 1211 694.7 1.4 1.0X ``` Performance result with this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Parquet: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ with this PR 1053 / 1121 760.0 1.3 1.0X ``` Here is a comparison between generated code w/o and with this PR. Only the method ```agg_doAggregateWithoutKey``` is changed. 
Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private scala.collection.Iterator scan_input; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows; /* 013 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_scanTime; /* 014 */ private long scan_scanTime1; /* 015 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch scan_batch; /* 016 */ private int scan_batchIdx; /* 017 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 018 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 019 */ private UnsafeRow agg_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 022 */ /* 023 */ public GeneratedIterator(Object[] references) { /* 024 */ this.references = references; /* 025 */ } /* 026 */ /* 027 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 028 */ partitionIndex = index; /* 029 */ this.inputs = inputs; /* 030 */ agg_initAgg = false; /* 031 */ /* 032 */ scan_input = inputs[0]; /* 033 */ this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 034 */ this.scan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 035 */ scan_scanTime1 = 0; /* 036 */ scan_batch = null; /* 037 */ scan_batchIdx = 0; /* 038 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 039 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 040 */ agg_result = new UnsafeRow(1); /* 041 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 042 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 043 */ /* 044 */ } /* 045 */ /* 046 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 047 */ // initialize aggregation buffer /* 048 */ agg_bufIsNull = false; /* 049 */ agg_bufValue = 0L; /* 050 */ /* 051 */ if (scan_batch == null) { /* 052 */ scan_nextBatch(); /* 053 */ } /* 054 */ while (scan_batch != null) { /* 055 */ int numRows = scan_batch.numRows(); /* 056 */ while (scan_batchIdx < numRows) { /* 057 */ int scan_rowIdx = scan_batchIdx++; /* 058 */ // do aggregate /* 059 */ // common sub-expressions /* 060 */ /* 061 */ // evaluate aggregate function /* 062 */ boolean agg_isNull1 = false; /* 063 */ /* 064 */ long agg_value1 = -1L; /* 065 */ agg_value1 = agg_bufValue + 1L; /* 066 */ // update aggregation buffer /* 067 */ agg_bufIsNull = false; /* 068 */ agg_bufValue = agg_value1; /* 069 */ if (shouldStop()) return; /* 070 */ } /* 071 */ scan_batch = null; /* 072 */ scan_nextBatch(); /* 073 */ } /* 074 */ scan_scanTime.add(scan_scanTime1 / (1000 * 1000)); /* 075 */ scan_scanTime1 = 0; /* 076 */ /* 077 */ } /* 078 */ /* 079 */ private void scan_nextBatch() throws java.io.IOException { /* 080 */ long getBatchStart = System.nanoTime(); /* 081 */ if (scan_input.hasNext()) { /* 082 */ scan_batch = (org.apache.spark.sql.execution.vectorized.ColumnarBatch)scan_input.next(); /* 083 */ 
scan_numOutputRows.add(scan_batch.numRows()); /* 084 */ scan_batchIdx = 0; /* 085 */ /* 086 */ } /* 087 */ scan_scanTime1 += System.nanoTime() - getBatchStart; /* 088 */ } /* 089 */ /* 090 */ protected void processNext() throws java.io.IOException { /* 091 */ while (!agg_initAgg) { /* 092 */ agg_initAgg = true; /* 093 */ long agg_beforeAgg = System.nanoTime(); /* 094 */ agg_doAggregateWithoutKey(); /* 095 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 096 */ /* 097 */ // output the result /* 098 */ /* 099 */ agg_numOutputRows.add(1); /* 100 */ agg_rowWriter.zeroOutNullBytes(); /* 101 */ /* 102 */ if (agg_bufIsNull) { /* 103 */ agg_rowWriter.setNullAt(0); /* 104 */ } else { /* 105 */ agg_rowWriter.write(0, agg_bufValue); /* 106 */ } /* 107 */ append(agg_result); /* 108 */ } /* 109 */ } /* 110 */ } ``` Generated code with this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private boolean agg_initAgg; /* 009 */ private boolean agg_bufIsNull; /* 010 */ private long agg_bufValue; /* 011 */ private scala.collection.Iterator scan_input; /* 012 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows; /* 013 */ private org.apache.spark.sql.execution.metric.SQLMetric scan_scanTime; /* 014 */ private long scan_scanTime1; /* 015 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch scan_batch; /* 016 */ private int scan_batchIdx; /* 017 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows; /* 018 */ private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime; /* 019 */ private UnsafeRow agg_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 022 */ /* 023 */ public GeneratedIterator(Object[] references) { /* 024 */ this.references = references; /* 025 */ } /* 026 */ /* 027 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 028 */ partitionIndex = index; /* 029 */ this.inputs = inputs; /* 030 */ agg_initAgg = false; /* 031 */ /* 032 */ scan_input = inputs[0]; /* 033 */ this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0]; /* 034 */ this.scan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1]; /* 035 */ scan_scanTime1 = 0; /* 036 */ scan_batch = null; /* 037 */ scan_batchIdx = 0; /* 038 */ this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2]; /* 039 */ this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3]; /* 040 */ agg_result = new UnsafeRow(1); /* 041 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 042 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 043 */ /* 044 */ } /* 045 */ /* 046 */ private void agg_doAggregateWithoutKey() throws java.io.IOException { /* 047 */ // initialize aggregation buffer /* 048 */ agg_bufIsNull = false; /* 049 */ agg_bufValue = 0L; /* 050 */ /* 051 */ if (scan_batch == null) { /* 052 */ scan_nextBatch(); /* 053 */ } /* 054 */ while (scan_batch != null) { /* 055 */ int numRows = scan_batch.numRows(); /* 056 */ int scan_localEnd = numRows - scan_batchIdx; /* 057 */ for (int scan_localIdx = 0; scan_localIdx < 
scan_localEnd; scan_localIdx++) { /* 058 */ int scan_rowIdx = scan_batchIdx + scan_localIdx; /* 059 */ // do aggregate /* 060 */ // common sub-expressions /* 061 */ /* 062 */ // evaluate aggregate function /* 063 */ boolean agg_isNull1 = false; /* 064 */ /* 065 */ long agg_value1 = -1L; /* 066 */ agg_value1 = agg_bufValue + 1L; /* 067 */ // update aggregation buffer /* 068 */ agg_bufIsNull = false; /* 069 */ agg_bufValue = agg_value1; /* 070 */ // shouldStop check is eliminated /* 071 */ } /* 072 */ scan_batchIdx = numRows; /* 073 */ scan_batch = null; /* 074 */ scan_nextBatch(); /* 075 */ } /* 079 */ } /* 080 */ /* 081 */ private void scan_nextBatch() throws java.io.IOException { /* 082 */ long getBatchStart = System.nanoTime(); /* 083 */ if (scan_input.hasNext()) { /* 084 */ scan_batch = (org.apache.spark.sql.execution.vectorized.ColumnarBatch)scan_input.next(); /* 085 */ scan_numOutputRows.add(scan_batch.numRows()); /* 086 */ scan_batchIdx = 0; /* 087 */ /* 088 */ } /* 089 */ scan_scanTime1 += System.nanoTime() - getBatchStart; /* 090 */ } /* 091 */ /* 092 */ protected void processNext() throws java.io.IOException { /* 093 */ while (!agg_initAgg) { /* 094 */ agg_initAgg = true; /* 095 */ long agg_beforeAgg = System.nanoTime(); /* 096 */ agg_doAggregateWithoutKey(); /* 097 */ agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000); /* 098 */ /* 099 */ // output the result /* 100 */ /* 101 */ agg_numOutputRows.add(1); /* 102 */ agg_rowWriter.zeroOutNullBytes(); /* 103 */ /* 104 */ if (agg_bufIsNull) { /* 105 */ agg_rowWriter.setNullAt(0); /* 106 */ } else { /* 107 */ agg_rowWriter.write(0, agg_bufValue); /* 108 */ } /* 109 */ append(agg_result); /* 110 */ } /* 111 */ } /* 112 */ } ``` ## How was this patch tested? Tested existing test suites Author: Kazuaki Ishizaki Closes #17378 from kiszk/SPARK-20046. 
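Distilled from the generated-code comparison above, a toy Scala sketch of the new loop shape (illustrative only, not the generated Java): the induction variable becomes an int local bounded by a precomputed end, the field is written back once per batch, and the per-row `shouldStop()` check is emitted only when the body actually calls `append()`.

```scala
// Toy stand-ins for the generated field and consume() body; only the loop shape matters.
object LoopShapeSketch {
  private var batchIdx: Int = 0
  private var acc: Long = 0L
  private def consume(rowIdx: Int): Unit = { acc += rowIdx }

  def processBatch(numRows: Int): Unit = {
    val localEnd = numRows - batchIdx
    var localIdx = 0
    while (localIdx < localEnd) {   // simple counted loop over an int local
      consume(batchIdx + localIdx)
      localIdx += 1                 // no per-row field update, no per-row shouldStop()
    }
    batchIdx = numRows              // the field is updated once per batch
  }
}
```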
--- .../sql/execution/ColumnarBatchScan.scala | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04fba17be4bf..e86116680a57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -111,17 +111,27 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } + val localIdx = ctx.freshName("localIdx") + val localEnd = ctx.freshName("localEnd") + val numRows = ctx.freshName("numRows") + val shouldStop = if (isShouldStopRequired) { + s"if (shouldStop()) { $idx = $rowidx + 1; return; }" + } else { + "// shouldStop check is eliminated" + } s""" |if ($batch == null) { | $nextBatch(); |} |while ($batch != null) { - | int numRows = $batch.numRows(); - | while ($idx < numRows) { - | int $rowidx = $idx++; + | int $numRows = $batch.numRows(); + | int $localEnd = $numRows - $idx; + | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { + | int $rowidx = $idx + $localIdx; | ${consume(ctx, columnsBatchInput).trim} - | if (shouldStop()) return; + | $shouldStop | } + | $idx = $numRows; | $batch = null; | $nextBatch(); |} From 362ee93296a0de6342b4339e941e6a11f445c5b2 Mon Sep 17 00:00:00 2001 From: Juan Rodriguez Hortala Date: Sun, 26 Mar 2017 10:39:05 +0100 Subject: [PATCH 135/512] logging improvements ## What changes were proposed in this pull request? Adding additional information to existing logging messages: - YarnAllocator: log the executor ID together with the container id when a container for an executor is launched. - NettyRpcEnv: log the receiver address when there is a timeout waiting for an answer to a remote call. - ExecutorAllocationManager: fix a typo in the logging message for the list of executors to be removed. ## How was this patch tested? Build spark and submit the word count example to a YARN cluster using cluster mode Author: Juan Rodriguez Hortala Closes #17411 from juanrh/logging-improvements. 
--- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 2 +- .../main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 3 ++- .../scala/org/apache/spark/deploy/yarn/YarnAllocator.scala | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1366251d0618..261b3329a7b9 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -439,7 +439,7 @@ private[spark] class ExecutorAllocationManager( executorsRemoved } else { logWarning(s"Unable to reach the cluster manager to kill executor/s " + - "executorIdsToBeRemoved.mkString(\",\") or no executor eligible to kill!") + s"${executorIdsToBeRemoved.mkString(",")} or no executor eligible to kill!") Seq.empty[String] } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index ff5e39a8dcbc..b316e5443f63 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -236,7 +236,8 @@ private[netty] class NettyRpcEnv( val timeoutCancelable = timeoutScheduler.schedule(new Runnable { override def run(): Unit = { - onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) + onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " + + s"in ${timeout.duration}")) } }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) promise.future.onComplete { v => diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index abd2de75c645..25556763da90 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -494,7 +494,8 @@ private[yarn] class YarnAllocator( val containerId = container.getId val executorId = executorIdCounter.toString assert(container.getResource.getMemory >= resource.getMemory) - logInfo(s"Launching container $containerId on host $executorHostname") + logInfo(s"Launching container $containerId on host $executorHostname " + + s"for executor with ID $executorId") def updateInternalState(): Unit = synchronized { numExecutorsRunning += 1 From 617ab6445ea33d8297f0691723fd19bae19228dc Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 26 Mar 2017 22:47:31 +0200 Subject: [PATCH 136/512] [SPARK-20086][SQL] CollapseWindow should not collapse dependent adjacent windows ## What changes were proposed in this pull request? The `CollapseWindow` is currently to aggressive when collapsing adjacent windows. It also collapses windows in the which the parent produces a column that is consumed by the child; this creates an invalid window which will fail at runtime. This PR fixes this by adding a check for dependent adjacent windows to the `CollapseWindow` rule. ## How was this patch tested? Added a new test case to `CollapseWindowSuite` Author: Herman van Hovell Closes #17432 from hvanhovell/SPARK-20086. 
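A DataFrame-level illustration of the dependent case the new check guards against, assuming an active `SparkSession` named `spark`: the second window expression consumes `sum_a`, which the first one produces, so the two adjacent Window operators must not be merged even though their partition and order specs match.

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{max, sum}
import spark.implicits._

val df = Seq((1, "x", 1), (2, "x", 2), (3, "y", 1)).toDF("a", "b", "c")
val w = Window.partitionBy("b").orderBy("c")

// Same specs, but the second window depends on the first one's output column,
// so CollapseWindow must keep them as two separate Window nodes.
val dependentWindows = df
  .withColumn("sum_a", sum("a").over(w))
  .withColumn("max_sum_a", max("sum_a").over(w))
```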
--- .../spark/sql/catalyst/optimizer/Optimizer.scala | 8 +++++--- .../sql/catalyst/optimizer/CollapseWindowSuite.scala | 11 +++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ee7de8692149..dbe3ded4bbf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -597,12 +597,14 @@ object CollapseRepartition extends Rule[LogicalPlan] { /** * Collapse Adjacent Window Expression. - * - If the partition specs and order specs are the same, collapse into the parent. + * - If the partition specs and order specs are the same and the window expression are + * independent, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case w @ Window(we1, ps1, os1, Window(we2, ps2, os2, grandChild)) if ps1 == ps2 && os1 == os2 => - w.copy(windowExpressions = we2 ++ we1, child = grandChild) + case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala index 3f7d1d9fd99a..52054c2f8bd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala @@ -78,4 +78,15 @@ class CollapseWindowSuite extends PlanTest { comparePlans(optimized2, correctAnswer2) } + + test("Don't collapse adjacent windows with dependent columns") { + val query = testRelation + .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1) + .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1) + .analyze + + val expected = query.analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } } From 0bc8847aa216497549c78ad49ec7ac066a059b15 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 26 Mar 2017 16:49:27 -0700 Subject: [PATCH 137/512] [SPARK-19281][PYTHON][ML] spark.ml Python API for FPGrowth ## What changes were proposed in this pull request? - Add `HasSupport` and `HasConfidence` `Params`. - Add new module `pyspark.ml.fpm`. - Add `FPGrowth` / `FPGrowthModel` wrappers. - Provide tests for new features. ## How was this patch tested? Unit tests. Author: zero323 Closes #17218 from zero323/SPARK-19281. 
--- dev/sparktestsupport/modules.py | 5 +- python/docs/pyspark.ml.rst | 8 ++ python/pyspark/ml/fpm.py | 216 ++++++++++++++++++++++++++++++++ python/pyspark/ml/tests.py | 53 ++++++-- 4 files changed, 273 insertions(+), 9 deletions(-) create mode 100644 python/pyspark/ml/fpm.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 10ad1fe3aa2c..eaf1f3a1db2f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -423,15 +423,16 @@ def __hash__(self): "python/pyspark/ml/" ], python_test_goals=[ - "pyspark.ml.feature", "pyspark.ml.classification", "pyspark.ml.clustering", + "pyspark.ml.evaluation", + "pyspark.ml.feature", + "pyspark.ml.fpm", "pyspark.ml.linalg.__init__", "pyspark.ml.recommendation", "pyspark.ml.regression", "pyspark.ml.tuning", "pyspark.ml.tests", - "pyspark.ml.evaluation", ], blacklisted_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 26f7415e1a42..a68183445d78 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -80,3 +80,11 @@ pyspark.ml.evaluation module :members: :undoc-members: :inherited-members: + +pyspark.ml.fpm module +---------------------------- + +.. automodule:: pyspark.ml.fpm + :members: + :undoc-members: + :inherited-members: diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py new file mode 100644 index 000000000000..b30d4edb1990 --- /dev/null +++ b/python/pyspark/ml/fpm.py @@ -0,0 +1,216 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import keyword_only, since +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import * + +__all__ = ["FPGrowth", "FPGrowthModel"] + + +class HasSupport(Params): + """ + Mixin for param support. + """ + + minSupport = Param( + Params._dummy(), + "minSupport", + """Minimal support level of the frequent pattern. [0.0, 1.0]. + Any pattern that appears more than (minSupport * size-of-the-dataset) + times will be output""", + typeConverter=TypeConverters.toFloat) + + def setMinSupport(self, value): + """ + Sets the value of :py:attr:`minSupport`. + """ + return self._set(minSupport=value) + + def getMinSupport(self): + """ + Gets the value of minSupport or its default value. + """ + return self.getOrDefault(self.minSupport) + + +class HasConfidence(Params): + """ + Mixin for param confidence. + """ + + minConfidence = Param( + Params._dummy(), + "minConfidence", + """Minimal confidence for generating Association Rule. 
[0.0, 1.0] + Note that minConfidence has no effect during fitting.""", + typeConverter=TypeConverters.toFloat) + + def setMinConfidence(self, value): + """ + Sets the value of :py:attr:`minConfidence`. + """ + return self._set(minConfidence=value) + + def getMinConfidence(self): + """ + Gets the value of minConfidence or its default value. + """ + return self.getOrDefault(self.minConfidence) + + +class HasItemsCol(Params): + """ + Mixin for param itemsCol: items column name. + """ + + itemsCol = Param(Params._dummy(), "itemsCol", + "items column name", typeConverter=TypeConverters.toString) + + def setItemsCol(self, value): + """ + Sets the value of :py:attr:`itemsCol`. + """ + return self._set(itemsCol=value) + + def getItemsCol(self): + """ + Gets the value of itemsCol or its default value. + """ + return self.getOrDefault(self.itemsCol) + + +class FPGrowthModel(JavaModel, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + Model fitted by FPGrowth. + + .. versionadded:: 2.2.0 + """ + @property + @since("2.2.0") + def freqItemsets(self): + """ + DataFrame with two columns: + * `items` - Itemset of the same type as the input column. + * `freq` - Frequency of the itemset (`LongType`). + """ + return self._call_java("freqItemsets") + + @property + @since("2.2.0") + def associationRules(self): + """ + Data with three columns: + * `antecedent` - Array of the same type as the input column. + * `consequent` - Array of the same type as the input column. + * `confidence` - Confidence for the rule (`DoubleType`). + """ + return self._call_java("associationRules") + + +class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, + HasSupport, HasConfidence, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + Li et al., PFP: Parallel FP-Growth for Query Recommendation [LI2008]_. + PFP distributes computation in such a way that each worker executes an + independent group of mining tasks. The FP-Growth algorithm is described in + Han et al., Mining frequent patterns without candidate generation [HAN2000]_ + + .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027 + .. [HAN2000] http://dx.doi.org/10.1145/335191.335372 + + .. note:: null values in the feature column are ignored during fit(). + .. note:: Internally `transform` `collects` and `broadcasts` association rules. + + >>> from pyspark.sql.functions import split + >>> data = (spark.read + ... .text("data/mllib/sample_fpgrowth.txt") + ... 
.select(split("value", "\s+").alias("items"))) + >>> data.show(truncate=False) + +------------------------+ + |items | + +------------------------+ + |[r, z, h, k, p] | + |[z, y, x, w, v, u, t, s]| + |[s, x, o, n, r] | + |[x, z, y, m, t, s, q, e]| + |[z] | + |[x, z, y, r, q, t, p] | + +------------------------+ + >>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7) + >>> fpm = fp.fit(data) + >>> fpm.freqItemsets.show(5) + +---------+----+ + | items|freq| + +---------+----+ + | [s]| 3| + | [s, x]| 3| + |[s, x, z]| 2| + | [s, z]| 2| + | [r]| 3| + +---------+----+ + only showing top 5 rows + >>> fpm.associationRules.show(5) + +----------+----------+----------+ + |antecedent|consequent|confidence| + +----------+----------+----------+ + | [t, s]| [y]| 1.0| + | [t, s]| [x]| 1.0| + | [t, s]| [z]| 1.0| + | [p]| [r]| 1.0| + | [p]| [z]| 1.0| + +----------+----------+----------+ + only showing top 5 rows + >>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"]) + >>> sorted(fpm.transform(new_data).first().prediction) + ['x', 'y', 'z'] + + .. versionadded:: 2.2.0 + """ + @keyword_only + def __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", + predictionCol="prediction", numPartitions=None): + """ + __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ + predictionCol="prediction", numPartitions=None) + """ + super(FPGrowth, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.FPGrowth", self.uid) + self._setDefault(minSupport=0.3, minConfidence=0.8, + itemsCol="items", predictionCol="prediction") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.2.0") + def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", + predictionCol="prediction", numPartitions=None): + """ + setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ + predictionCol="prediction", numPartitions=None) + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return FPGrowthModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index cc559db58720..527db9b66793 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -42,7 +42,7 @@ import array as pyarray import numpy as np from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) + abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros) from numpy import sum as array_sum import inspect @@ -50,18 +50,20 @@ from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import * from pyspark.ml.clustering import * +from pyspark.ml.common import _java2py, _py2java from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * -from pyspark.ml.linalg import Vector, SparseVector, DenseVector, VectorUDT,\ - DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT, _convert_to_vector +from pyspark.ml.fpm import FPGrowth, FPGrowthModel +from pyspark.ml.linalg import ( + DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, + SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector) from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed +from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import 
LinearRegression, DecisionTreeRegressor, \ - GeneralizedLinearRegression +from pyspark.ml.regression import ( + DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression) from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams, JavaWrapper -from pyspark.ml.common import _java2py, _py2java from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand @@ -1243,6 +1245,43 @@ def test_tweedie_distribution(self): self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) +class FPGrowthTests(SparkSessionTestCase): + def setUp(self): + super(FPGrowthTests, self).setUp() + self.data = self.spark.createDataFrame( + [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], + ["items"]) + + def test_association_rules(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_association_rules = self.spark.createDataFrame( + [([3], [1], 1.0), ([2], [1], 1.0)], + ["antecedent", "consequent", "confidence"] + ) + actual_association_rules = fpm.associationRules + + self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) + self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) + + def test_freq_itemsets(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_freq_itemsets = self.spark.createDataFrame( + [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], + ["items", "freq"] + ) + actual_freq_itemsets = fpm.freqItemsets + + self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) + self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) + + def tearDown(self): + del self.data + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): From 3fbf0a5f9297f438bc92db11f106d4a0ae568613 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 26 Mar 2017 18:40:00 -0700 Subject: [PATCH 138/512] [MINOR][DOCS] Match several documentation changes in Scala to R/Python ## What changes were proposed in this pull request? This PR proposes to match minor documentations changes in https://github.com/apache/spark/pull/17399 and https://github.com/apache/spark/pull/17380 to R/Python. ## How was this patch tested? Manual tests in Python , Python tests via `./python/run-tests.py --module=pyspark-sql` and lint-checks for Python/R. Author: hyukjinkwon Closes #17429 from HyukjinKwon/minor-match-doc. --- R/pkg/R/functions.R | 6 +++--- python/pyspark/sql/functions.py | 8 ++++---- python/pyspark/sql/tests.py | 8 ++++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 2cff3ac08c3a..449476dec533 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2632,8 +2632,8 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' format_number #' -#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, -#' and returns the result as a string column. +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places +#' with HALF_EVEN round mode, and returns the result as a string column. #' #' If x is 0, the result has no decimal point or fractional part. #' If x < 0, the result will be null. @@ -3548,7 +3548,7 @@ setMethod("row_number", #' array_contains #' -#' Returns true if the array contain the value. +#' Returns null if the array is null, true if the array contains the value, and false otherwise. 
#' #' @param x A Column #' @param value A value to be checked if contained in the column diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f9121e60f35b..843ae3816f06 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1327,8 +1327,8 @@ def encode(col, charset): @since(1.5) def format_number(col, d): """ - Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places, - and returns the result as a string. + Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places + with HALF_EVEN round mode, and returns the result as a string. :param col: the column name of the numeric value to be formatted :param d: the N decimal places @@ -1675,8 +1675,8 @@ def array(*cols): @since(1.5) def array_contains(col, value): """ - Collection function: returns True if the array contains the given value. The collection - elements and value must be of the same type. + Collection function: returns null if the array is null, true if the array contains the + given value, and false otherwise. :param col: name of column containing array :param value: value to check for in array diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b93b7ed19210..db41b4edb6dd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1129,6 +1129,14 @@ def test_rand_functions(self): rndn2 = df.select('key', functions.randn(0)).collect() self.assertEqual(sorted(rndn1), sorted(rndn2)) + def test_array_contains_function(self): + from pyspark.sql.functions import array_contains + + df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data']) + actual = df.select(array_contains(df.data, 1).alias('b')).collect() + # The value argument can be implicitly castable to the element's type of the array. + self.assertEqual([Row(b=True), Row(b=False)], actual) + def test_between_function(self): df = self.sc.parallelize([ Row(a=1, b=2, c=3), From 890493458de396cfcffdd71233cfdd39e834944b Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Mon, 27 Mar 2017 23:41:27 +0800 Subject: [PATCH 139/512] [SPARK-20104][SQL] Don't estimate IsNull or IsNotNull predicates for non-leaf node ## What changes were proposed in this pull request? In current stage, we don't have advanced statistics such as sketches or histograms. As a result, some operator can't estimate `nullCount` accurately. E.g. left outer join estimation does not accurately update `nullCount` currently. So for `IsNull` and `IsNotNull` predicates, we only estimate them when the child is a leaf node, whose `nullCount` is accurate. ## How was this patch tested? A new test case is added in `FilterEstimationSuite`. Author: wangzhenhua Closes #17438 from wzhfy/nullEstimation. 
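For context (an editor's illustration, not part of this patch; the table and column names below are made up), the sketch shows how a left outer join manufactures NULLs that the leaf tables' column statistics know nothing about, which is why `nullCount` above a non-leaf child cannot be trusted without more advanced statistics:

```scala
// Minimal local-mode sketch (can be pasted into spark-shell); illustrative names only.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("outer-join-nulls").getOrCreate()
import spark.implicits._

val larger  = (1 to 30).toDF("c1")   // leaf stats: 30 rows, c1 has nullCount = 0
val smaller = (1 to 10).toDF("key")  // leaf stats: 10 rows, key has nullCount = 0

// Rows 11..30 of `larger` find no match, so `key` becomes NULL for them after the join:
// a Filter(IsNull(key)) above the join matches 20 rows even though the leaf-level
// nullCount of `key` is 0. That gap is what the new LeafNode guard avoids relying on.
val joined = larger.join(smaller, $"c1" === $"key", "left_outer")
println(joined.filter("key IS NULL").count())   // 20
```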
--- .../statsEstimation/FilterEstimation.scala | 12 ++++++--- .../FilterEstimationSuite.scala | 25 ++++++++++++++++++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index b10785b05d6c..f14df93160b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -24,7 +24,7 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -174,10 +174,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case InSet(ar: Attribute, set) => evaluateInSet(ar, set, update) - case IsNull(ar: Attribute) => + // In current stage, we don't have advanced statistics such as sketches or histograms. + // As a result, some operator can't estimate `nullCount` accurately. E.g. left outer join + // estimation does not accurately update `nullCount` currently. + // So for IsNull and IsNotNull predicates, we only estimate them when the child is a leaf + // node, whose `nullCount` is accurate. + // This is a limitation due to lack of advanced stats. We should remove it in the future. + case IsNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = true, update) - case IsNotNull(ar: Attribute) => + case IsNotNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = false, update) case _ => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 4691913c8c98..07abe1ed2853 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types._ @@ -340,6 +341,28 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 2) } + // This is a limitation test. We should remove it after the limitation is removed. 
+ test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { + val attrIntLargerRange = AttributeReference("c1", IntegerType)() + val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 10, avgLen = 4, maxLen = 4) + val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) + val largerTable = StatsTestPlan( + outputList = Seq(attrIntLargerRange), + rowCount = 30, + attributeStats = AttributeMap(Seq(attrIntLargerRange -> colStatIntLargerRange))) + val nonLeafChild = Join(largerTable, smallerTable, LeftOuter, + Some(EqualTo(attrIntLargerRange, attrInt))) + + Seq(IsNull(attrIntLargerRange), IsNotNull(attrIntLargerRange)).foreach { predicate => + validateEstimatedStats( + Filter(predicate, nonLeafChild), + // column stats don't change + Seq(attrInt -> colStatInt, attrIntLargerRange -> colStatIntLargerRange), + expectedRowCount = 30) + } + } + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, From 0588dc7c0a9f3180dddae0dc202a6d41eb43464f Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 27 Mar 2017 08:53:45 -0700 Subject: [PATCH 140/512] [SPARK-20088] Do not create new SparkContext in SparkR createSparkContext ## What changes were proposed in this pull request? Instead of creating new `JavaSparkContext` we use `SparkContext.getOrCreate`. ## How was this patch tested? Existing tests Author: Hossein Closes #17423 from falaki/SPARK-20088. --- core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 72ae0340aa3d..295355c7bf01 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -136,7 +136,7 @@ private[r] object RRDD { .mkString(File.separator)) } - val jsc = new JavaSparkContext(sparkConf) + val jsc = new JavaSparkContext(SparkContext.getOrCreate(sparkConf)) jars.foreach { jar => jsc.addJar(jar) } From 314cf51ded52834cfbaacf58d3d05a220965ca2a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 27 Mar 2017 10:23:28 -0700 Subject: [PATCH 141/512] [SPARK-20102] Fix nightly packaging and RC packaging scripts w/ two minor build fixes ## What changes were proposed in this pull request? The master snapshot publisher builds are currently broken due to two minor build issues: 1. For unknown reasons, the LFTP `mkdir -p` command began throwing errors when the remote directory already exists. This change of behavior might have been caused by configuration changes in the ASF's SFTP server, but I'm not entirely sure of that. To work around this problem, this patch updates the script to ignore errors from the `lftp mkdir -p` commands. 2. The PySpark `setup.py` file references a non-existent `pyspark.ml.stat` module, causing Python packaging to fail by complaining about a missing directory. The fix is to simply drop that line from the setup script. ## How was this patch tested? The LFTP fix was tested by manually running the failing commands on AMPLab Jenkins against the ASF SFTP server. The PySpark fix was tested locally. Author: Josh Rosen Closes #17437 from JoshRosen/spark-20102. 
--- dev/create-release/release-build.sh | 8 ++++---- python/setup.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index e1db997a7d41..7976d8a03954 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -246,7 +246,7 @@ if [[ "$1" == "package" ]]; then dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" echo "Copying release tarballs to $dest_dir" # Put to new directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mput -O $dest_dir 'spark-*' LFTP mput -O $dest_dir 'pyspark-*' LFTP mput -O $dest_dir 'SparkR_*' @@ -254,7 +254,7 @@ if [[ "$1" == "package" ]]; then LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" # Re-upload a second time and leave the files in the timestamped upload directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mput -O $dest_dir 'spark-*' LFTP mput -O $dest_dir 'pyspark-*' LFTP mput -O $dest_dir 'SparkR_*' @@ -271,13 +271,13 @@ if [[ "$1" == "docs" ]]; then PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build echo "Copying release documentation to $dest_dir" # Put to new directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mirror -R _site $dest_dir # Delete /latest directory and rename new upload to /latest LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" # Re-upload a second time and leave the files in the timestamped upload directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mirror -R _site $dest_dir cd .. exit 0 diff --git a/python/setup.py b/python/setup.py index 47eab98e0f7b..f50035435e26 100644 --- a/python/setup.py +++ b/python/setup.py @@ -167,7 +167,6 @@ def _supports_symlinks(): 'pyspark.ml', 'pyspark.ml.linalg', 'pyspark.ml.param', - 'pyspark.ml.stat', 'pyspark.sql', 'pyspark.streaming', 'pyspark.bin', From 3fada2f502107bd5572fb895471943de7b2c38e4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 27 Mar 2017 10:43:00 -0700 Subject: [PATCH 142/512] [SPARK-20105][TESTS][R] Add tests for checkType and type string in structField in R ## What changes were proposed in this pull request? It seems `checkType` and the type string in `structField` are not being tested closely. This string format currently seems SparkR-specific (see https://github.com/apache/spark/blob/d1f6c64c4b763c05d6d79ae5497f298dc3835f3e/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala#L93-L131) but resembles SQL type definition. Therefore, it seems nicer if we test positive/negative cases in R side. ## How was this patch tested? Unit tests in `test_sparkSQL.R`. Author: hyukjinkwon Closes #17439 from HyukjinKwon/r-typestring-tests. 
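As a point of comparison only (an editor's note, not part of this patch): the SparkR type strings exercised here are validated by SparkR's own `checkType`/`SQLUtils` code path, but they deliberately resemble the SQL-style type syntax that Catalyst parses, and the `dataType.toString()` values asserted in the new R tests line up with how the corresponding Catalyst types print. A quick spark-shell sketch, assuming only that the Catalyst parser is available as usual:

```scala
// Goes through Catalyst's parser, not SparkR's checkType; shown only for comparison.
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

println(CatalystSqlParser.parseDataType("map<string,integer>"))
// MapType(StringType,IntegerType,true)
println(CatalystSqlParser.parseDataType("array<string>"))
// ArrayType(StringType,true)
println(CatalystSqlParser.parseDataType("struct<a:string>"))
// StructType(StructField(a,StringType,true))
```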
--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 53 +++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 394d1a04e09c..5acf8719d120 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -140,6 +140,59 @@ test_that("structType and structField", { expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) +test_that("structField type strings", { + # positive cases + primitiveTypes <- list(byte = "ByteType", + integer = "IntegerType", + float = "FloatType", + double = "DoubleType", + string = "StringType", + binary = "BinaryType", + boolean = "BooleanType", + timestamp = "TimestampType", + date = "DateType") + + complexTypes <- list("map" = "MapType(StringType,IntegerType,true)", + "array" = "ArrayType(StringType,true)", + "struct" = "StructType(StructField(a,StringType,true))") + + typeList <- c(primitiveTypes, complexTypes) + typeStrings <- names(typeList) + + for (i in seq_along(typeStrings)){ + typeString <- typeStrings[i] + expected <- typeList[[i]] + testField <- structField("_col", typeString) + expect_is(testField, "structField") + expect_true(testField$nullable()) + expect_equal(testField$dataType.toString(), expected) + } + + # negative cases + primitiveErrors <- list(Byte = "Byte", + INTEGER = "INTEGER", + numeric = "numeric", + character = "character", + raw = "raw", + logical = "logical") + + complexErrors <- list("map" = " integer", + "array" = "String", + "struct" = "string ", + "map " = "map ", + "array< string>" = " string", + "struct" = " string") + + errorList <- c(primitiveErrors, complexErrors) + typeStrings <- names(errorList) + + for (i in seq_along(typeStrings)){ + typeString <- typeStrings[i] + expected <- paste0("Unsupported type for SparkDataframe: ", errorList[[i]]) + expect_error(structField("_col", typeString), expected) + } +}) + test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) From 1d00761b9176a1f42976057ca78638c5b0763abc Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 27 Mar 2017 17:37:24 -0700 Subject: [PATCH 143/512] [MINOR][SPARKR] Move 'Data type mapping between R and Spark' to right place in SparkR doc. Section ```Data type mapping between R and Spark``` was put in the wrong place in SparkR doc currently, we should move it to a separate section. ## What changes were proposed in this pull request? Before this PR: ![image](https://cloud.githubusercontent.com/assets/1962026/24340911/bc01a532-126a-11e7-9a08-0d60d13a547c.png) After this PR: ![image](https://cloud.githubusercontent.com/assets/1962026/24340938/d9d32a9a-126a-11e7-8891-d2f5b46e0c71.png) Author: Yanbo Liang Closes #17440 from yanboliang/sparkr-doc. --- docs/sparkr.md | 138 ++++++++++++++++++++++++------------------------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index d7ffd9b3f122..a1a35a7757e5 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -394,75 +394,6 @@ head(result[order(result$max_eruption, decreasing = TRUE), ]) {% endhighlight %}
-#### Data type mapping between R and Spark
-<table class="table">
-<tr><th>R</th><th>Spark</th></tr>
-<tr><td>byte</td><td>byte</td></tr>
-<tr><td>integer</td><td>integer</td></tr>
-<tr><td>float</td><td>float</td></tr>
-<tr><td>double</td><td>double</td></tr>
-<tr><td>numeric</td><td>double</td></tr>
-<tr><td>character</td><td>string</td></tr>
-<tr><td>string</td><td>string</td></tr>
-<tr><td>binary</td><td>binary</td></tr>
-<tr><td>raw</td><td>binary</td></tr>
-<tr><td>logical</td><td>boolean</td></tr>
-<tr><td>POSIXct</td><td>timestamp</td></tr>
-<tr><td>POSIXlt</td><td>timestamp</td></tr>
-<tr><td>Date</td><td>date</td></tr>
-<tr><td>array</td><td>array</td></tr>
-<tr><td>list</td><td>array</td></tr>
-<tr><td>env</td><td>map</td></tr>
-</table>
-
 #### Run local R functions distributed using `spark.lapply`
 
 ##### spark.lapply
@@ -557,6 +488,75 @@ SparkR supports a subset of the available R formula operators for model fitting,
 
 The following example shows how to save/load a MLlib model by SparkR.
 {% include_example read_write r/ml/ml.R %}
 
+# Data type mapping between R and Spark
+<table class="table">
+<tr><th>R</th><th>Spark</th></tr>
+<tr><td>byte</td><td>byte</td></tr>
+<tr><td>integer</td><td>integer</td></tr>
+<tr><td>float</td><td>float</td></tr>
+<tr><td>double</td><td>double</td></tr>
+<tr><td>numeric</td><td>double</td></tr>
+<tr><td>character</td><td>string</td></tr>
+<tr><td>string</td><td>string</td></tr>
+<tr><td>binary</td><td>binary</td></tr>
+<tr><td>raw</td><td>binary</td></tr>
+<tr><td>logical</td><td>boolean</td></tr>
+<tr><td>POSIXct</td><td>timestamp</td></tr>
+<tr><td>POSIXlt</td><td>timestamp</td></tr>
+<tr><td>Date</td><td>date</td></tr>
+<tr><td>array</td><td>array</td></tr>
+<tr><td>list</td><td>array</td></tr>
+<tr><td>env</td><td>map</td></tr>
+</table>
      + # R Function Name Conflicts When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a From a250933c625ed720d15a0e479e9c51113605b102 Mon Sep 17 00:00:00 2001 From: Shubham Chopra Date: Tue, 28 Mar 2017 09:47:29 +0800 Subject: [PATCH 144/512] [SPARK-19803][CORE][TEST] Proactive replication test failures ## What changes were proposed in this pull request? Executors cache a list of their peers that is refreshed by default every minute. The cached stale references were randomly being used for replication. Since those executors were removed from the master, they did not occur in the block locations as reported by the master. This was fixed by 1. Refreshing peer cache in the block manager before trying to pro-actively replicate. This way the probability of replicating to a failed executor is eliminated. 2. Explicitly stopping the block manager in the tests. This shuts down the RPC endpoint use by the block manager. This way, even if a block manager tries to replicate using a stale reference, the replication logic should take care of refreshing the list of peers after failure. ## How was this patch tested? Tested manually Author: Shubham Chopra Author: Kay Ousterhout Author: Shubham Chopra Closes #17325 from shubhamchopra/SPARK-19803. --- .../spark/storage/BlockInfoManager.scala | 6 ++++ .../apache/spark/storage/BlockManager.scala | 6 +++- .../BlockManagerReplicationSuite.scala | 29 ++++++++++++------- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 490d45d12b8e..3db59837fbeb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -371,6 +371,12 @@ private[storage] class BlockInfoManager extends Logging { blocksWithReleasedLocks } + /** Returns the number of locks held by the given task. Used only for testing. */ + private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = { + readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) + + writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0) + } + /** * Returns the number of blocks tracked. 
*/ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 245d94ac4f8b..991346a40af4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1187,7 +1187,7 @@ private[spark] class BlockManager( blockId: BlockId, existingReplicas: Set[BlockManagerId], maxReplicas: Int): Unit = { - logInfo(s"Pro-actively replicating $blockId") + logInfo(s"Using $blockManagerId to pro-actively replicate $blockId") blockInfoManager.lockForReading(blockId).foreach { info => val data = doGetLocalBytes(blockId, info) val storageLevel = StorageLevel( @@ -1196,9 +1196,13 @@ private[spark] class BlockManager( useOffHeap = info.level.useOffHeap, deserialized = info.level.deserialized, replication = maxReplicas) + // we know we are called as a result of an executor removal, so we refresh peer cache + // this way, we won't try to replicate to a missing executor with a stale reference + getPeers(forceFetch = true) try { replicate(blockId, data, storageLevel, info.classTag, existingReplicas) } finally { + logDebug(s"Releasing lock for $blockId") releaseLock(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d907add920c8..d5715f8469f7 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -493,27 +493,34 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav assert(blockLocations.size === replicationFactor) // remove a random blockManager - val executorsToRemove = blockLocations.take(replicationFactor - 1) + val executorsToRemove = blockLocations.take(replicationFactor - 1).toSet logInfo(s"Removing $executorsToRemove") - executorsToRemove.foreach{exec => - master.removeExecutor(exec.executorId) + initialStores.filter(bm => executorsToRemove.contains(bm.blockManagerId)).foreach { bm => + master.removeExecutor(bm.blockManagerId.executorId) + bm.stop() // giving enough time for replication to happen and new block be reported to master - Thread.sleep(200) + eventually(timeout(5 seconds), interval(100 millis)) { + val newLocations = master.getLocations(blockId).toSet + assert(newLocations.size === replicationFactor) + } } - val newLocations = eventually(timeout(5 seconds), interval(10 millis)) { + val newLocations = eventually(timeout(5 seconds), interval(100 millis)) { val _newLocations = master.getLocations(blockId).toSet assert(_newLocations.size === replicationFactor) _newLocations } logInfo(s"New locations : $newLocations") - // there should only be one common block manager between initial and new locations - assert(newLocations.intersect(blockLocations.toSet).size === 1) - // check if all the read locks have been released - initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => - val locks = bm.releaseAllLocksForTask(BlockInfo.NON_TASK_WRITER) - assert(locks.size === 0, "Read locks unreleased!") + // new locations should not contain stopped block managers + assert(newLocations.forall(bmId => !executorsToRemove.contains(bmId)), + "New locations contain stopped block managers.") + + // Make sure all locks have been released. 
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => + assert(bm.blockInfoManager.getTaskLockCount(BlockInfo.NON_TASK_WRITER) === 0) + } } } } From 8a6f33f0483dcee81467e6374a796b5dbd53ea30 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 27 Mar 2017 19:04:16 -0700 Subject: [PATCH 145/512] [SPARK-19876][SS] Follow up: Refactored BatchCommitLog to simplify logic ## What changes were proposed in this pull request? Existing logic seemingly writes null to the BatchCommitLog, even though it does additional checks to write '{}' (valid json) to the log. This PR simplifies the logic by disallowing use of `log.add(batchId, metadata)` and instead using `log.add(batchId)`. No question of specifying metadata, so no confusion related to null. ## How was this patch tested? Existing tests pass. Author: Tathagata Das Closes #17444 from tdas/SPARK-19876-1. --- .../execution/streaming/BatchCommitLog.scala | 28 +++++++++++-------- .../execution/streaming/HDFSMetadataLog.scala | 1 + .../execution/streaming/StreamExecution.scala | 2 +- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala index fb1a4fb9b12f..a34938f911f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala @@ -45,33 +45,39 @@ import org.apache.spark.sql.SparkSession class BatchCommitLog(sparkSession: SparkSession, path: String) extends HDFSMetadataLog[String](sparkSession, path) { + import BatchCommitLog._ + + def add(batchId: Long): Unit = { + super.add(batchId, EMPTY_JSON) + } + + override def add(batchId: Long, metadata: String): Boolean = { + throw new UnsupportedOperationException( + "BatchCommitLog does not take any metadata, use 'add(batchId)' instead") + } + override protected def deserialize(in: InputStream): String = { // called inside a try-finally where the underlying stream is closed in the caller val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file in the offset commit log") } - parseVersion(lines.next().trim, BatchCommitLog.VERSION) - // read metadata - lines.next().trim match { - case BatchCommitLog.SERIALIZED_VOID => null - case metadata => metadata - } + parseVersion(lines.next.trim, VERSION) + EMPTY_JSON } override protected def serialize(metadata: String, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(s"v${BatchCommitLog.VERSION}".getBytes(UTF_8)) + out.write(s"v${VERSION}".getBytes(UTF_8)) out.write('\n') - // write metadata or void - out.write((if (metadata == null) BatchCommitLog.SERIALIZED_VOID else metadata) - .getBytes(UTF_8)) + // write metadata + out.write(EMPTY_JSON.getBytes(UTF_8)) } } object BatchCommitLog { private val VERSION = 1 - private val SERIALIZED_VOID = "{}" + private val EMPTY_JSON = "{}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 60ce64261c4a..46bfc297931f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -106,6 +106,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: * metadata has already been stored, this method will return `false`. */ override def add(batchId: Long, metadata: T): Boolean = { + require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written writeBatch(batchId, metadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 34e9262af7cb..5f548172f5ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -305,7 +305,7 @@ class StreamExecution( if (dataAvailable) { // Update committed offsets. committedOffsets ++= availableOffsets - batchCommitLog.add(currentBatchId, null) + batchCommitLog.add(currentBatchId) logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 From ea361165e1ddce4d8aa0242ae3e878d7b39f1de2 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 28 Mar 2017 10:07:24 +0800 Subject: [PATCH 146/512] [SPARK-20100][SQL] Refactor SessionState initialization ## What changes were proposed in this pull request? The current SessionState initialization code path is quite complex. A part of the creation is done in the SessionState companion objects, a part of the creation is one inside the SessionState class, and a part is done by passing functions. This PR refactors this code path, and consolidates SessionState initialization into a builder class. This SessionState will not do any initialization and just becomes a place holder for the various Spark SQL internals. This also lays the ground work for two future improvements: 1. This provides us with a start for removing the `HiveSessionState`. Removing the `HiveSessionState` would also require us to move resource loading into a separate class, and to (re)move metadata hive. 2. This makes it easier to customize the Spark Session. Currently you will need to create a custom version of the builder. I have added hooks to facilitate this. A future step will be to create a semi stable API on top of this. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17433 from hvanhovell/SPARK-20100. 
--- .../sql/catalyst/catalog/SessionCatalog.scala | 46 +-- .../sql/catalyst/optimizer/Optimizer.scala | 16 +- .../catalog/SessionCatalogSuite.scala | 22 +- .../spark/sql/execution/SparkOptimizer.scala | 12 +- .../spark/sql/execution/SparkPlanner.scala | 11 +- .../streaming/IncrementalExecution.scala | 23 +- .../spark/sql/internal/SessionState.scala | 180 +++-------- .../sql/internal/sessionStateBuilders.scala | 279 ++++++++++++++++++ .../spark/sql/test/TestSQLContext.scala | 23 +- .../spark/sql/hive/HiveSessionCatalog.scala | 76 +---- .../spark/sql/hive/HiveSessionState.scala | 259 +++++++--------- .../apache/spark/sql/hive/test/TestHive.scala | 60 ++-- .../sql/hive/HiveSessionCatalogSuite.scala | 112 ------- 13 files changed, 547 insertions(+), 572 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a469d1245164..72ab07540889 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -54,7 +54,8 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: CatalystConf, hadoopConf: Configuration, - parser: ParserInterface) extends Logging { + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader) extends Logging { import SessionCatalog._ import CatalogTypes.TablePartitionSpec @@ -69,8 +70,8 @@ class SessionCatalog( functionRegistry, conf, new Configuration(), - CatalystSqlParser) - functionResourceLoader = DummyFunctionResourceLoader + CatalystSqlParser, + DummyFunctionResourceLoader) } // For testing only. @@ -90,9 +91,7 @@ class SessionCatalog( // check whether the temporary table or function exists, then, if not, operate on // the corresponding item in the current database. @GuardedBy("this") - protected var currentDb = formatDatabaseName(DEFAULT_DATABASE) - - @volatile var functionResourceLoader: FunctionResourceLoader = _ + protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) /** * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), @@ -1059,9 +1058,6 @@ class SessionCatalog( * by a tuple (resource type, resource uri). */ def loadFunctionResources(resources: Seq[FunctionResource]): Unit = { - if (functionResourceLoader == null) { - throw new IllegalStateException("functionResourceLoader has not yet been initialized") - } resources.foreach(functionResourceLoader.loadResource) } @@ -1259,28 +1255,16 @@ class SessionCatalog( } /** - * Create a new [[SessionCatalog]] with the provided parameters. `externalCatalog` and - * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied. + * Copy the current state of the catalog to another catalog. + * + * This function is synchronized on this [[SessionCatalog]] (the source) to make sure the copied + * state is consistent. The target [[SessionCatalog]] is not synchronized, and should not be + * because the target [[SessionCatalog]] should not be published at this point. The caller must + * synchronize on the target if this assumption does not hold. 
*/ - def newSessionCatalogWith( - conf: CatalystConf, - hadoopConf: Configuration, - functionRegistry: FunctionRegistry, - parser: ParserInterface): SessionCatalog = { - val catalog = new SessionCatalog( - externalCatalog, - globalTempViewManager, - functionRegistry, - conf, - hadoopConf, - parser) - - synchronized { - catalog.currentDb = currentDb - // copy over temporary tables - tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2)) - } - - catalog + private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized { + target.currentDb = currentDb + // copy over temporary tables + tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index dbe3ded4bbf1..dbf479d21513 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,20 +17,14 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.annotation.tailrec -import scala.collection.immutable.HashSet import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -79,7 +73,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Aggregate", fixedPoint, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: - Batch("Operator Optimizations", fixedPoint, + Batch("Operator Optimizations", fixedPoint, Seq( // Operator push down PushProjectionThroughUnion, ReorderJoin(conf), @@ -117,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) RemoveRedundantProject, SimplifyCreateStructOps, SimplifyCreateArrayOps, - SimplifyCreateMapOps) :: + SimplifyCreateMapOps) ++ + extendedOperatorOptimizationRules: _*) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: Batch("Join Reorder", Once, @@ -146,6 +141,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) s.withNewPlan(newPlan) } } + + /** + * Override to provide additional rules for the operator optimization batch. 
+ */ + def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index ca4ce1c11707..56bca73a8857 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.hadoop.conf.Configuration - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ @@ -1331,17 +1329,15 @@ abstract class SessionCatalogSuite extends PlanTest { } } - test("clone SessionCatalog - temp views") { + test("copy SessionCatalog state - temp views") { withEmptyCatalog { original => val tempTable1 = Range(1, 10, 1, 10) original.createTempView("copytest1", tempTable1, overrideIfExists = false) // check if tables copied over - val clone = original.newSessionCatalogWith( - SimpleCatalystConf(caseSensitiveAnalysis = true), - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) + val clone = new SessionCatalog(original.externalCatalog) + original.copyStateTo(clone) + assert(original ne clone) assert(clone.getTempView("copytest1") == Some(tempTable1)) @@ -1355,7 +1351,7 @@ abstract class SessionCatalogSuite extends PlanTest { } } - test("clone SessionCatalog - current db") { + test("copy SessionCatalog state - current db") { withEmptyCatalog { original => val db1 = "db1" val db2 = "db2" @@ -1368,11 +1364,9 @@ abstract class SessionCatalogSuite extends PlanTest { original.setCurrentDatabase(db1) // check if current db copied over - val clone = original.newSessionCatalogWith( - SimpleCatalystConf(caseSensitiveAnalysis = true), - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) + val clone = new SessionCatalog(original.externalCatalog) + original.copyStateTo(clone) + assert(original ne clone) assert(clone.getCurrentDatabase == db1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 981728331d36..2cdfb7a7828c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -30,9 +30,17 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog, conf) { - override def batches: Seq[Batch] = super.batches :+ + override def batches: Seq[Batch] = (super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + + /** + * Optimization batches that are executed after the regular optimization batches, but before the + * batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add + * custom optimizer batches to the Spark optimizer. 
+ */ + def postHocOptimizationBatches: Seq[Batch] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 678241656c01..6566502bd8a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -27,13 +27,14 @@ import org.apache.spark.sql.internal.SQLConf class SparkPlanner( val sparkContext: SparkContext, val conf: SQLConf, - val extraStrategies: Seq[Strategy]) + val experimentalMethods: ExperimentalMethods) extends SparkStrategies { def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - extraStrategies ++ ( + experimentalMethods.extraStrategies ++ + extraPlanningStrategies ++ ( FileSourceStrategy :: DataSourceStrategy :: SpecialLimits :: @@ -42,6 +43,12 @@ class SparkPlanner( InMemoryScans :: BasicOperators :: Nil) + /** + * Override to add extra planning strategies to the planner. These strategies are tried after + * the strategies defined in [[ExperimentalMethods]], and before the regular strategies. + */ + def extraPlanningStrategies: Seq[Strategy] = Nil + override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = { plan.collect { case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 0f0e4a91f8cc..622e049630db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SparkSession, Strategy} +import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} @@ -40,20 +40,17 @@ class IncrementalExecution( offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { - // TODO: make this always part of planning. - val streamingExtraStrategies = - sparkSession.sessionState.planner.StatefulAggregationStrategy +: - sparkSession.sessionState.planner.FlatMapGroupsWithStateStrategy +: - sparkSession.sessionState.planner.StreamingRelationStrategy +: - sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: - sparkSession.sessionState.experimentalMethods.extraStrategies - // Modified planner with stateful operations. 
- override def planner: SparkPlanner = - new SparkPlanner( + override val planner: SparkPlanner = new SparkPlanner( sparkSession.sparkContext, sparkSession.sessionState.conf, - streamingExtraStrategies) + sparkSession.sessionState.experimentalMethods) { + override def extraPlanningStrategies: Seq[Strategy] = + StatefulAggregationStrategy :: + FlatMapGroupsWithStateStrategy :: + StreamingRelationStrategy :: + StreamingDeduplicationStrategy :: Nil + } /** * See [SPARK-18339] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index ce80604bd365..b5b0bb0bfc40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -22,22 +22,21 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager - /** * A class that holds all session-specific state in a given [[SparkSession]]. + * * @param sparkContext The [[SparkContext]]. * @param sharedState The shared state. * @param conf SQL-specific key-value configurations. @@ -46,9 +45,11 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @param catalog Internal catalog for managing table and database states. * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. - * @param streamingQueryManager Interface to start and stop - * [[org.apache.spark.sql.streaming.StreamingQuery]]s. - * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]] + * @param optimizer Logical query plan optimizer. + * @param planner Planner that converts optimized logical plans to physical plans + * @param streamingQueryManager Interface to start and stop streaming queries. + * @param createQueryExecution Function used to create QueryExecution objects. + * @param createClone Function used to create clones of the session state. 
*/ private[sql] class SessionState( sparkContext: SparkContext, @@ -59,8 +60,11 @@ private[sql] class SessionState( val catalog: SessionCatalog, val sqlParser: ParserInterface, val analyzer: Analyzer, + val optimizer: Optimizer, + val planner: SparkPlanner, val streamingQueryManager: StreamingQueryManager, - val queryExecutionCreator: LogicalPlan => QueryExecution) { + createQueryExecution: LogicalPlan => QueryExecution, + createClone: (SparkSession, SessionState) => SessionState) { def newHadoopConf(): Configuration = SessionState.newHadoopConf( sparkContext.hadoopConfiguration, @@ -76,41 +80,12 @@ private[sql] class SessionState( hadoopConf } - /** - * A class for loading resources specified by a function. - */ - val functionResourceLoader: FunctionResourceLoader = { - new FunctionResourceLoader { - override def loadResource(resource: FunctionResource): Unit = { - resource.resourceType match { - case JarResource => addJar(resource.uri) - case FileResource => sparkContext.addFile(resource.uri) - case ArchiveResource => - throw new AnalysisException( - "Archive is not allowed to be loaded. If YARN mode is used, " + - "please use --archives options while calling spark-submit.") - } - } - } - } - /** * Interface exposed to the user for registering user-defined functions. * Note that the user-defined functions must be deterministic. */ val udf: UDFRegistration = new UDFRegistration(functionRegistry) - /** - * Logical query plan optimizer. - */ - val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods) - - /** - * Planner that converts optimized logical plans to physical plans. - */ - def planner: SparkPlanner = - new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies) - /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. 
@@ -120,38 +95,13 @@ private[sql] class SessionState( /** * Get an identical copy of the `SessionState` and associate it with the given `SparkSession` */ - def clone(newSparkSession: SparkSession): SessionState = { - val sparkContext = newSparkSession.sparkContext - val confCopy = conf.clone() - val functionRegistryCopy = functionRegistry.clone() - val sqlParser: ParserInterface = new SparkSqlParser(confCopy) - val catalogCopy = catalog.newSessionCatalogWith( - confCopy, - SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy), - functionRegistryCopy, - sqlParser) - val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan) - - SessionState.mergeSparkConf(confCopy, sparkContext.getConf) - - new SessionState( - sparkContext, - newSparkSession.sharedState, - confCopy, - experimentalMethods.clone(), - functionRegistryCopy, - catalogCopy, - sqlParser, - SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy), - new StreamingQueryManager(newSparkSession), - queryExecutionCreator) - } + def clone(newSparkSession: SparkSession): SessionState = createClone(newSparkSession, this) // ------------------------------------------------------ // Helper methods, partially leftover from pre-2.0 days // ------------------------------------------------------ - def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan) + def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan) def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) @@ -179,53 +129,12 @@ private[sql] class SessionState( } } - private[sql] object SessionState { - - def apply(sparkSession: SparkSession): SessionState = { - apply(sparkSession, new SQLConf) - } - - def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = { - val sparkContext = sparkSession.sparkContext - - // Automatically extract all entries and put them in our SQLConf - mergeSparkConf(sqlConf, sparkContext.getConf) - - val functionRegistry = FunctionRegistry.builtin.clone() - - val sqlParser: ParserInterface = new SparkSqlParser(sqlConf) - - val catalog = new SessionCatalog( - sparkSession.sharedState.externalCatalog, - sparkSession.sharedState.globalTempViewManager, - functionRegistry, - sqlConf, - newHadoopConf(sparkContext.hadoopConfiguration, sqlConf), - sqlParser) - - val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf) - - val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession) - - val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan) - - val sessionState = new SessionState( - sparkContext, - sparkSession.sharedState, - sqlConf, - new ExperimentalMethods, - functionRegistry, - catalog, - sqlParser, - analyzer, - streamingQueryManager, - queryExecutionCreator) - // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before - // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller - // cannot use SessionCatalog before we return SessionState. - catalog.functionResourceLoader = sessionState.functionResourceLoader - sessionState + /** + * Create a new [[SessionState]] for the given session. 
+ */ + def apply(session: SparkSession): SessionState = { + new SessionStateBuilder(session).build() } def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = { @@ -233,34 +142,33 @@ private[sql] object SessionState { sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) } newHadoopConf } +} - /** - * Create an logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`. - */ - private def createAnalyzer( - sparkSession: SparkSession, - catalog: SessionCatalog, - sqlConf: SQLConf): Analyzer = { - new Analyzer(catalog, sqlConf) { - override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(sqlConf) :: - DataSourceAnalysis(sqlConf) :: Nil - - override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck) - } - } +/** + * Concrete implementation of a [[SessionStateBuilder]]. + */ +@Experimental +@InterfaceStability.Unstable +class SessionStateBuilder( + session: SparkSession, + parentState: Option[SessionState] = None) + extends BaseSessionStateBuilder(session, parentState) { + override protected def newBuilder: NewBuilder = new SessionStateBuilder(_, _) +} - /** - * Extract entries from `SparkConf` and put them in the `SQLConf` - */ - def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = { - sparkConf.getAll.foreach { case (k, v) => - sqlConf.setConfString(k, v) +/** + * Session shared [[FunctionResourceLoader]]. + */ +@InterfaceStability.Unstable +class SessionFunctionResourceLoader(session: SparkSession) extends FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + resource.resourceType match { + case JarResource => session.sessionState.addJar(resource.uri) + case FileResource => session.sparkContext.addFile(resource.uri) + case ArchiveResource => + throw new AnalysisException( + "Archive is not allowed to be loaded. If YARN mode is used, " + + "please use --archives options while calling spark-submit.") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala new file mode 100644 index 000000000000..6b5559adb1db --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.internal + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.streaming.StreamingQueryManager + +/** + * Builder class that coordinates construction of a new [[SessionState]]. + * + * The builder explicitly defines all components needed by the session state, and creates a session + * state when `build` is called. Components should only be initialized once. This is not a problem + * for most components as they are only used in the `build` function. However some components + * (`conf`, `catalog`, `functionRegistry`, `experimentalMethods` & `sqlParser`) are as dependencies + * for other components and are shared as a result. These components are defined as lazy vals to + * make sure the component is created only once. + * + * A developer can modify the builder by providing custom versions of components, or by using the + * hooks provided for the analyzer, optimizer & planner. There are some dependencies between the + * components (they are documented per dependency), a developer should respect these when making + * modifications in order to prevent initialization problems. + * + * A parent [[SessionState]] can be used to initialize the new [[SessionState]]. The new session + * state will clone the parent sessions state's `conf`, `functionRegistry`, `experimentalMethods` + * and `catalog` fields. Note that the state is cloned when `build` is called, and not before. + */ +@Experimental +@InterfaceStability.Unstable +abstract class BaseSessionStateBuilder( + val session: SparkSession, + val parentState: Option[SessionState] = None) { + type NewBuilder = (SparkSession, Option[SessionState]) => BaseSessionStateBuilder + + /** + * Function that produces a new instance of the SessionStateBuilder. This is used by the + * [[SessionState]]'s clone functionality. Make sure to override this when implementing your own + * [[SessionStateBuilder]]. + */ + protected def newBuilder: NewBuilder + + /** + * Extract entries from `SparkConf` and put them in the `SQLConf` + */ + protected def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = { + sparkConf.getAll.foreach { case (k, v) => + sqlConf.setConfString(k, v) + } + } + + /** + * SQL-specific key-value configurations. + * + * These either get cloned from a pre-existing instance or newly created. The conf is always + * merged with its [[SparkConf]]. + */ + protected lazy val conf: SQLConf = { + val conf = parentState.map(_.conf.clone()).getOrElse(new SQLConf) + mergeSparkConf(conf, session.sparkContext.conf) + conf + } + + /** + * Internal catalog managing functions registered by the user. + * + * This either gets cloned from a pre-existing version or cloned from the built-in registry. 
+ */ + protected lazy val functionRegistry: FunctionRegistry = { + parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone() + } + + /** + * Experimental methods that can be used to define custom optimization rules and custom planning + * strategies. + * + * This either gets cloned from a pre-existing version or newly created. + */ + protected lazy val experimentalMethods: ExperimentalMethods = { + parentState.map(_.experimentalMethods.clone()).getOrElse(new ExperimentalMethods) + } + + /** + * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + * + * Note: this depends on the `conf` field. + */ + protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + + /** + * Catalog for managing table and database states. If there is a pre-existing catalog, the state + * of that catalog (temp tables & current database) will be copied into the new catalog. + * + * Note: this depends on the `conf`, `functionRegistry` and `sqlParser` fields. + */ + protected lazy val catalog: SessionCatalog = { + val catalog = new SessionCatalog( + session.sharedState.externalCatalog, + session.sharedState.globalTempViewManager, + functionRegistry, + conf, + SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), + sqlParser, + new SessionFunctionResourceLoader(session)) + parentState.foreach(_.catalog.copyStateTo(catalog)) + catalog + } + + /** + * Logical query plan analyzer for resolving unresolved attributes and relations. + * + * Note: this depends on the `conf` and `catalog` fields. + */ + protected def analyzer: Analyzer = new Analyzer(catalog, conf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new FindDataSourceTable(session) +: + new ResolveSQLOnFile(session) +: + customResolutionRules + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + PreprocessTableCreation(session) +: + PreprocessTableInsertion(conf) +: + DataSourceAnalysis(conf) +: + customPostHocResolutionRules + + override val extendedCheckRules: Seq[LogicalPlan => Unit] = + PreWriteCheck +: + HiveOnlyCheck +: + customCheckRules + } + + /** + * Custom resolution rules to add to the Analyzer. Prefer overriding this instead of creating + * your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil + + /** + * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of + * creating your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + + /** + * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating + * your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil + + /** + * Logical query plan optimizer. + * + * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. + */ + protected def optimizer: Optimizer = { + new SparkOptimizer(catalog, conf, experimentalMethods) { + override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = + super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules + } + } + + /** + * Custom operator optimization rules to add to the Optimizer. Prefer overriding this instead + * of creating your own Optimizer. + * + * Note that this may NOT depend on the `optimizer` function. 
+ */ + protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + + /** + * Planner that converts optimized logical plans to physical plans. + * + * Note: this depends on the `conf` and `experimentalMethods` fields. + */ + protected def planner: SparkPlanner = { + new SparkPlanner(session.sparkContext, conf, experimentalMethods) { + override def extraPlanningStrategies: Seq[Strategy] = + super.extraPlanningStrategies ++ customPlanningStrategies + } + } + + /** + * Custom strategies to add to the planner. Prefer overriding this instead of creating + * your own Planner. + * + * Note that this may NOT depend on the `planner` function. + */ + protected def customPlanningStrategies: Seq[Strategy] = Nil + + /** + * Create a query execution object. + */ + protected def createQueryExecution: LogicalPlan => QueryExecution = { plan => + new QueryExecution(session, plan) + } + + /** + * Interface to start and stop streaming queries. + */ + protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session) + + /** + * Function used to make clones of the session state. + */ + protected def createClone: (SparkSession, SessionState) => SessionState = { + val createBuilder = newBuilder + (session, state) => createBuilder(session, Option(state)).build() + } + + /** + * Build the [[SessionState]]. + */ + def build(): SessionState = { + new SessionState( + session.sparkContext, + session.sharedState, + conf, + experimentalMethods, + functionRegistry, + catalog, + sqlParser, + analyzer, + optimizer, + planner, + streamingQueryManager, + createQueryExecution, + createClone) + } +} + +/** + * Helper class for using SessionStateBuilders during tests. + */ +private[sql] trait WithTestConf { self: BaseSessionStateBuilder => + def overrideConfs: Map[String, String] + + override protected lazy val conf: SQLConf = { + val conf = parentState.map(_.conf.clone()).getOrElse { + new SQLConf { + clear() + override def clear(): Unit = { + super.clear() + // Make sure we start with the default test configs even after clear + overrideConfs.foreach { case (key, value) => setConfString(key, value) } + } + } + } + mergeSparkConf(conf, session.sparkContext.conf) + conf + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 898a2fb4f329..b01977a23890 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf} /** * A special [[SparkSession]] prepared for testing. 
@@ -35,16 +35,9 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } @transient - override lazy val sessionState: SessionState = SessionState( - this, - new SQLConf { - clear() - override def clear(): Unit = { - super.clear() - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } - } - }) + override lazy val sessionState: SessionState = { + new TestSQLSessionStateBuilder(this, None).build() + } // Needed for Java tests def loadTestData(): Unit = { @@ -67,3 +60,11 @@ private[sql] object TestSQLContext { // Fewer shuffle partitions to speed up testing. SQLConf.SHUFFLE_PARTITIONS.key -> "5") } + +private[sql] class TestSQLSessionStateBuilder( + session: SparkSession, + state: Option[SessionState]) + extends SessionStateBuilder(session, state) with WithTestConf { + override def overrideConfs: Map[String, String] = TestSQLContext.overrideConfs + override def newBuilder: NewBuilder = new TestSQLSessionStateBuilder(_, _) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 6b7599e3d340..2cc20a791d80 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -25,8 +25,8 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} @@ -47,14 +47,16 @@ private[sql] class HiveSessionCatalog( functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, - parser: ParserInterface) + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader) extends SessionCatalog( externalCatalog, globalTempViewManager, functionRegistry, conf, hadoopConf, - parser) { + parser, + functionResourceLoader) { // ---------------------------------------------------------------- // | Methods and fields for interacting with HiveMetastoreCatalog | @@ -69,47 +71,6 @@ private[sql] class HiveSessionCatalog( metastoreCatalog.hiveDefaultTableFilePath(name) } - /** - * Create a new [[HiveSessionCatalog]] with the provided parameters. `externalCatalog` and - * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied. 
- */ - def newSessionCatalogWith( - newSparkSession: SparkSession, - conf: SQLConf, - hadoopConf: Configuration, - functionRegistry: FunctionRegistry, - parser: ParserInterface): HiveSessionCatalog = { - val catalog = HiveSessionCatalog( - newSparkSession, - functionRegistry, - conf, - hadoopConf, - parser) - - synchronized { - catalog.currentDb = currentDb - // copy over temporary tables - tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2)) - } - - catalog - } - - /** - * The parent class [[SessionCatalog]] cannot access the [[SparkSession]] class, so we cannot add - * a [[SparkSession]] parameter to [[SessionCatalog.newSessionCatalogWith]]. However, - * [[HiveSessionCatalog]] requires a [[SparkSession]] parameter, so we can a new version of - * `newSessionCatalogWith` and disable this one. - * - * TODO Refactor HiveSessionCatalog to not use [[SparkSession]] directly. - */ - override def newSessionCatalogWith( - conf: CatalystConf, - hadoopConf: Configuration, - functionRegistry: FunctionRegistry, - parser: ParserInterface): HiveSessionCatalog = throw new UnsupportedOperationException( - "to clone HiveSessionCatalog, use the other clone method that also accepts a SparkSession") - // For testing only private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { val key = metastoreCatalog.getQualifiedTableName(table) @@ -250,28 +211,3 @@ private[sql] class HiveSessionCatalog( "histogram_numeric" ) } - -private[sql] object HiveSessionCatalog { - - def apply( - sparkSession: SparkSession, - functionRegistry: FunctionRegistry, - conf: SQLConf, - hadoopConf: Configuration, - parser: ParserInterface): HiveSessionCatalog = { - // Catalog for handling data source tables. TODO: This really doesn't belong here since it is - // essentially a cache for metastore tables. However, it relies on a lot of session-specific - // things so it would be a lot of work to split its functionality between HiveSessionCatalog - // and HiveCatalog. We should still do it at some point... 
- val metastoreCatalog = new HiveMetastoreCatalog(sparkSession) - - new HiveSessionCatalog( - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog], - sparkSession.sharedState.globalTempViewManager, - metastoreCatalog, - functionRegistry, - conf, - hadoopConf, - parser) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index cb8bcb8591bd..49ff8478f1ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -18,20 +18,23 @@ package org.apache.spark.sql.hive import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} +import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionFunctionResourceLoader, SessionState, SharedState, SQLConf} import org.apache.spark.sql.streaming.StreamingQueryManager /** * A class that holds all session-specific state in a given [[SparkSession]] backed by Hive. + * * @param sparkContext The [[SparkContext]]. * @param sharedState The shared state. * @param conf SQL-specific key-value configurations. @@ -40,12 +43,14 @@ import org.apache.spark.sql.streaming.StreamingQueryManager * @param catalog Internal catalog for managing table and database states that uses Hive client for * interacting with the metastore. * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. - * @param metadataHive The Hive metadata client. * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. - * @param streamingQueryManager Interface to start and stop - * [[org.apache.spark.sql.streaming.StreamingQuery]]s. - * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]] - * @param plannerCreator Lambda to create a planner that takes into account Hive-specific strategies + * @param optimizer Logical query plan optimizer. + * @param planner Planner that converts optimized logical plans to physical plans and that takes + * Hive-specific strategies into account. + * @param streamingQueryManager Interface to start and stop streaming queries. + * @param createQueryExecution Function used to create QueryExecution objects. + * @param createClone Function used to create clones of the session state. + * @param metadataHive The Hive metadata client. 
*/ private[hive] class HiveSessionState( sparkContext: SparkContext, @@ -55,11 +60,13 @@ private[hive] class HiveSessionState( functionRegistry: FunctionRegistry, override val catalog: HiveSessionCatalog, sqlParser: ParserInterface, - val metadataHive: HiveClient, analyzer: Analyzer, + optimizer: Optimizer, + planner: SparkPlanner, streamingQueryManager: StreamingQueryManager, - queryExecutionCreator: LogicalPlan => QueryExecution, - val plannerCreator: () => SparkPlanner) + createQueryExecution: LogicalPlan => QueryExecution, + createClone: (SparkSession, SessionState) => SessionState, + val metadataHive: HiveClient) extends SessionState( sparkContext, sharedState, @@ -69,14 +76,11 @@ private[hive] class HiveSessionState( catalog, sqlParser, analyzer, + optimizer, + planner, streamingQueryManager, - queryExecutionCreator) { self => - - /** - * Planner that takes into account Hive-specific strategies. - */ - override def planner: SparkPlanner = plannerCreator() - + createQueryExecution, + createClone) { // ------------------------------------------------------ // Helper methods, partially leftover from pre-2.0 days @@ -121,150 +125,115 @@ private[hive] class HiveSessionState( def hiveThriftServerAsync: Boolean = { conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) } +} +private[hive] object HiveSessionState { /** - * Get an identical copy of the `HiveSessionState`. - * This should ideally reuse the `SessionState.clone` but cannot do so. - * Doing that will throw an exception when trying to clone the catalog. + * Create a new [[HiveSessionState]] for the given session. */ - override def clone(newSparkSession: SparkSession): HiveSessionState = { - val sparkContext = newSparkSession.sparkContext - val confCopy = conf.clone() - val functionRegistryCopy = functionRegistry.clone() - val experimentalMethodsCopy = experimentalMethods.clone() - val sqlParser: ParserInterface = new SparkSqlParser(confCopy) - val catalogCopy = catalog.newSessionCatalogWith( - newSparkSession, - confCopy, - SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy), - functionRegistryCopy, - sqlParser) - val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan) - - val hiveClient = - newSparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .newSession() - - SessionState.mergeSparkConf(confCopy, sparkContext.getConf) - - new HiveSessionState( - sparkContext, - newSparkSession.sharedState, - confCopy, - experimentalMethodsCopy, - functionRegistryCopy, - catalogCopy, - sqlParser, - hiveClient, - HiveSessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy), - new StreamingQueryManager(newSparkSession), - queryExecutionCreator, - HiveSessionState.createPlannerCreator( - newSparkSession, - confCopy, - experimentalMethodsCopy)) + def apply(session: SparkSession): HiveSessionState = { + new HiveSessionStateBuilder(session).build() } - } -private[hive] object HiveSessionState { - - def apply(sparkSession: SparkSession): HiveSessionState = { - apply(sparkSession, new SQLConf) - } - - def apply(sparkSession: SparkSession, conf: SQLConf): HiveSessionState = { - val initHelper = SessionState(sparkSession, conf) - - val sparkContext = sparkSession.sparkContext - - val catalog = HiveSessionCatalog( - sparkSession, - initHelper.functionRegistry, - initHelper.conf, - SessionState.newHadoopConf(sparkContext.hadoopConfiguration, initHelper.conf), - initHelper.sqlParser) - - val metadataHive: HiveClient = - 
sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .newSession() - - val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, initHelper.conf) +/** + * Builder that produces a [[HiveSessionState]]. + */ +@Experimental +@InterfaceStability.Unstable +class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) + extends BaseSessionStateBuilder(session, parentState) { - val plannerCreator = createPlannerCreator( - sparkSession, - initHelper.conf, - initHelper.experimentalMethods) + private def externalCatalog: HiveExternalCatalog = + session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] - val hiveSessionState = new HiveSessionState( - sparkContext, - sparkSession.sharedState, - initHelper.conf, - initHelper.experimentalMethods, - initHelper.functionRegistry, - catalog, - initHelper.sqlParser, - metadataHive, - analyzer, - initHelper.streamingQueryManager, - initHelper.queryExecutionCreator, - plannerCreator) - catalog.functionResourceLoader = hiveSessionState.functionResourceLoader - hiveSessionState + /** + * Create a [[HiveSessionCatalog]]. + */ + override protected lazy val catalog: HiveSessionCatalog = { + val catalog = new HiveSessionCatalog( + externalCatalog, + session.sharedState.globalTempViewManager, + new HiveMetastoreCatalog(session), + functionRegistry, + conf, + SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), + sqlParser, + new SessionFunctionResourceLoader(session)) + parentState.foreach(_.catalog.copyStateTo(catalog)) + catalog } /** - * Create an logical query plan `Analyzer` with rules specific to a `HiveSessionState`. + * A logical query plan `Analyzer` with rules specific to Hive. */ - private def createAnalyzer( - sparkSession: SparkSession, - catalog: HiveSessionCatalog, - sqlConf: SQLConf): Analyzer = { - new Analyzer(catalog, sqlConf) { - override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = - new ResolveHiveSerdeTable(sparkSession) :: - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = - new DetermineTableStats(sparkSession) :: - catalog.ParquetConversions :: - catalog.OrcConversions :: - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(sqlConf) :: - DataSourceAnalysis(sqlConf) :: - HiveAnalysis :: Nil + override protected def analyzer: Analyzer = new Analyzer(catalog, conf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new ResolveHiveSerdeTable(session) +: + new FindDataSourceTable(session) +: + new ResolveSQLOnFile(session) +: + customResolutionRules + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + new DetermineTableStats(session) +: + catalog.ParquetConversions +: + catalog.OrcConversions +: + PreprocessTableCreation(session) +: + PreprocessTableInsertion(conf) +: + DataSourceAnalysis(conf) +: + HiveAnalysis +: + customPostHocResolutionRules + + override val extendedCheckRules: Seq[LogicalPlan => Unit] = + PreWriteCheck +: + customCheckRules + } - override val extendedCheckRules = Seq(PreWriteCheck) + /** + * Planner that takes into account Hive-specific strategies. 
+ */ + override protected def planner: SparkPlanner = { + new SparkPlanner(session.sparkContext, conf, experimentalMethods) with HiveStrategies { + override val sparkSession: SparkSession = session + + override def extraPlanningStrategies: Seq[Strategy] = + super.extraPlanningStrategies ++ customPlanningStrategies + + override def strategies: Seq[Strategy] = { + experimentalMethods.extraStrategies ++ + extraPlanningStrategies ++ Seq( + FileSourceStrategy, + DataSourceStrategy, + SpecialLimits, + InMemoryScans, + HiveTableScans, + Scripts, + Aggregation, + JoinSelection, + BasicOperators + ) + } } } - private def createPlannerCreator( - associatedSparkSession: SparkSession, - sqlConf: SQLConf, - experimentalMethods: ExperimentalMethods): () => SparkPlanner = { - () => - new SparkPlanner( - associatedSparkSession.sparkContext, - sqlConf, - experimentalMethods.extraStrategies) - with HiveStrategies { - - override val sparkSession: SparkSession = associatedSparkSession + override protected def newBuilder: NewBuilder = new HiveSessionStateBuilder(_, _) - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy, - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } - } + override def build(): HiveSessionState = { + val metadataHive: HiveClient = externalCatalog.client.newSession() + new HiveSessionState( + session.sparkContext, + session.sharedState, + conf, + experimentalMethods, + functionRegistry, + catalog, + sqlParser, + analyzer, + optimizer, + planner, + streamingQueryManager, + createQueryExecution, + createClone, + metadataHive) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index b63ed76967bd..32ca69605ef4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -148,12 +148,14 @@ class TestHiveContext( * * @param sc SparkContext * @param existingSharedState optional [[SharedState]] + * @param parentSessionState optional parent [[SessionState]] * @param loadTestTables if true, load the test tables. They can only be loaded when running * in the JVM, i.e when calling from Python this flag has to be false. 
*/ private[hive] class TestHiveSparkSession( @transient private val sc: SparkContext, @transient private val existingSharedState: Option[TestHiveSharedState], + @transient private val parentSessionState: Option[HiveSessionState], private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => @@ -161,6 +163,7 @@ private[hive] class TestHiveSparkSession( this( sc, existingSharedState = None, + parentSessionState = None, loadTestTables) } @@ -168,6 +171,7 @@ private[hive] class TestHiveSparkSession( this( sc, existingSharedState = Some(new TestHiveSharedState(sc, Some(hiveClient))), + parentSessionState = None, loadTestTables) } @@ -192,36 +196,21 @@ private[hive] class TestHiveSparkSession( @transient override lazy val sessionState: HiveSessionState = { - val testConf = - new SQLConf { - clear() - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - override def clear(): Unit = { - super.clear() - TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) } - } - } - val queryExecutionCreator = (plan: LogicalPlan) => new TestHiveQueryExecution(this, plan) - val initHelper = HiveSessionState(this, testConf) - SessionState.mergeSparkConf(testConf, sparkContext.getConf) - - new HiveSessionState( - sparkContext, - sharedState, - testConf, - initHelper.experimentalMethods, - initHelper.functionRegistry, - initHelper.catalog, - initHelper.sqlParser, - initHelper.metadataHive, - initHelper.analyzer, - initHelper.streamingQueryManager, - queryExecutionCreator, - initHelper.plannerCreator) + new TestHiveSessionStateBuilder(this, parentSessionState).build() } override def newSession(): TestHiveSparkSession = { - new TestHiveSparkSession(sc, Some(sharedState), loadTestTables) + new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) + } + + override def cloneSession(): SparkSession = { + val result = new TestHiveSparkSession( + sparkContext, + Some(sharedState), + Some(sessionState), + loadTestTables) + result.sessionState // force copy of SessionState + result } private var cacheTables: Boolean = false @@ -595,3 +584,18 @@ private[hive] object TestHiveContext { } } + +private[sql] class TestHiveSessionStateBuilder( + session: SparkSession, + state: Option[SessionState]) + extends HiveSessionStateBuilder(session, state) + with WithTestConf { + + override def overrideConfs: Map[String, String] = TestHiveContext.overrideConfs + + override def createQueryExecution: (LogicalPlan) => QueryExecution = { plan => + new TestHiveQueryExecution(session.asInstanceOf[TestHiveSparkSession], plan) + } + + override protected def newBuilder: NewBuilder = new TestHiveSessionStateBuilder(_, _) +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala deleted file mode 100644 index 3b0f59b15916..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI - -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry -import org.apache.spark.sql.catalyst.catalog.CatalogDatabase -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.Utils - -class HiveSessionCatalogSuite extends TestHiveSingleton { - - test("clone HiveSessionCatalog") { - val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - - val tempTableName1 = "copytest1" - val tempTableName2 = "copytest2" - try { - val tempTable1 = Range(1, 10, 1, 10) - original.createTempView(tempTableName1, tempTable1, overrideIfExists = false) - - // check if tables copied over - val clone = original.newSessionCatalogWith( - spark, - new SQLConf, - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) - assert(original ne clone) - assert(clone.getTempView(tempTableName1) == Some(tempTable1)) - - // check if clone and original independent - clone.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = false, purge = false) - assert(original.getTempView(tempTableName1) == Some(tempTable1)) - - val tempTable2 = Range(1, 20, 2, 10) - original.createTempView(tempTableName2, tempTable2, overrideIfExists = false) - assert(clone.getTempView(tempTableName2).isEmpty) - } finally { - // Drop the created temp views from the global singleton HiveSession. - original.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = true, purge = true) - original.dropTable(TableIdentifier(tempTableName2), ignoreIfNotExists = true, purge = true) - } - } - - test("clone SessionCatalog - current db") { - val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - val originalCurrentDatabase = original.getCurrentDatabase - val db1 = "db1" - val db2 = "db2" - val db3 = "db3" - try { - original.createDatabase(newDb(db1), ignoreIfExists = true) - original.createDatabase(newDb(db2), ignoreIfExists = true) - original.createDatabase(newDb(db3), ignoreIfExists = true) - - original.setCurrentDatabase(db1) - - // check if tables copied over - val clone = original.newSessionCatalogWith( - spark, - new SQLConf, - new Configuration(), - new SimpleFunctionRegistry, - CatalystSqlParser) - - // check if current db copied over - assert(original ne clone) - assert(clone.getCurrentDatabase == db1) - - // check if clone and original independent - clone.setCurrentDatabase(db2) - assert(original.getCurrentDatabase == db1) - original.setCurrentDatabase(db3) - assert(clone.getCurrentDatabase == db2) - } finally { - // Drop the created databases from the global singleton HiveSession. 
- original.dropDatabase(db1, ignoreIfNotExists = true, cascade = true) - original.dropDatabase(db2, ignoreIfNotExists = true, cascade = true) - original.dropDatabase(db3, ignoreIfNotExists = true, cascade = true) - original.setCurrentDatabase(originalCurrentDatabase) - } - } - - def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) - - def newDb(name: String): CatalogDatabase = { - CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) - } -} From 6c70a38c2e60e1b69a310aee1a92ee0b3815c02d Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Tue, 28 Mar 2017 10:09:49 +0800 Subject: [PATCH 147/512] [SPARK-19088][SQL] Optimize sequence type deserialization codegen ## What changes were proposed in this pull request? Optimization of arbitrary Scala sequence deserialization introduced by #16240. The previous implementation constructed an array which was then converted by `to`. This required two passes in most cases. This implementation attempts to remedy that by using `Builder`s provided by the `newBuilder` method on every Scala collection's companion object to build the resulting collection directly. Example codegen for simple `List` (obtained using `Seq(List(1)).toDS().map(identity).queryExecution.debug.codegen`): Before: ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private boolean deserializetoobject_resultIsNull; /* 010 */ private java.lang.Object[] deserializetoobject_argValue; /* 011 */ private boolean MapObjects_loopIsNull1; /* 012 */ private int MapObjects_loopValue0; /* 013 */ private boolean deserializetoobject_resultIsNull1; /* 014 */ private scala.collection.generic.CanBuildFrom deserializetoobject_argValue1; /* 015 */ private UnsafeRow deserializetoobject_result; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 018 */ private scala.collection.immutable.List mapelements_argValue; /* 019 */ private UnsafeRow mapelements_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 022 */ private scala.collection.immutable.List serializefromobject_argValue; /* 023 */ private UnsafeRow serializefromobject_result; /* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 025 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 026 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter; /* 027 */ /* 028 */ public GeneratedIterator(Object[] references) { /* 029 */ this.references = references; /* 030 */ } /* 031 */ /* 032 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 033 */ partitionIndex = index; /* 034 */ this.inputs = inputs; /* 035 */ inputadapter_input = inputs[0]; /* 036 */ /* 037 */ deserializetoobject_result = new UnsafeRow(1); /* 038 */ 
this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32); /* 039 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 040 */ /* 041 */ mapelements_result = new UnsafeRow(1); /* 042 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32); /* 043 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 044 */ /* 045 */ serializefromobject_result = new UnsafeRow(1); /* 046 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32); /* 047 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 048 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 049 */ /* 050 */ } /* 051 */ /* 052 */ protected void processNext() throws java.io.IOException { /* 053 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 054 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 055 */ ArrayData inputadapter_value = inputadapter_row.getArray(0); /* 056 */ /* 057 */ deserializetoobject_resultIsNull = false; /* 058 */ /* 059 */ if (!deserializetoobject_resultIsNull) { /* 060 */ ArrayData deserializetoobject_value3 = null; /* 061 */ /* 062 */ if (!false) { /* 063 */ Integer[] deserializetoobject_convertedArray = null; /* 064 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 065 */ deserializetoobject_convertedArray = new Integer[deserializetoobject_dataLength]; /* 066 */ /* 067 */ int deserializetoobject_loopIndex = 0; /* 068 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 069 */ MapObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex)); /* 070 */ MapObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 071 */ /* 072 */ if (MapObjects_loopIsNull1) { /* 073 */ throw new RuntimeException(((java.lang.String) references[0])); /* 074 */ } /* 075 */ if (false) { /* 076 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 077 */ } else { /* 078 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue0; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_loopIndex += 1; /* 082 */ } /* 083 */ /* 084 */ deserializetoobject_value3 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /* 085 */ } /* 086 */ boolean deserializetoobject_isNull2 = true; /* 087 */ java.lang.Object[] deserializetoobject_value2 = null; /* 088 */ if (!false) { /* 089 */ deserializetoobject_isNull2 = false; /* 090 */ if (!deserializetoobject_isNull2) { /* 091 */ Object deserializetoobject_funcResult = null; /* 092 */ deserializetoobject_funcResult = deserializetoobject_value3.array(); /* 093 */ if (deserializetoobject_funcResult == null) { /* 094 */ deserializetoobject_isNull2 = true; /* 095 */ } else { /* 096 */ deserializetoobject_value2 = (java.lang.Object[]) deserializetoobject_funcResult; /* 097 */ } /* 098 */ /* 099 */ } /* 100 */ deserializetoobject_isNull2 = deserializetoobject_value2 == null; /* 101 */ } /* 102 */ deserializetoobject_resultIsNull = 
deserializetoobject_isNull2; /* 103 */ deserializetoobject_argValue = deserializetoobject_value2; /* 104 */ } /* 105 */ /* 106 */ boolean deserializetoobject_isNull1 = deserializetoobject_resultIsNull; /* 107 */ final scala.collection.Seq deserializetoobject_value1 = deserializetoobject_resultIsNull ? null : scala.collection.mutable.WrappedArray.make(deserializetoobject_argValue); /* 108 */ deserializetoobject_isNull1 = deserializetoobject_value1 == null; /* 109 */ boolean deserializetoobject_isNull = true; /* 110 */ scala.collection.immutable.List deserializetoobject_value = null; /* 111 */ if (!deserializetoobject_isNull1) { /* 112 */ deserializetoobject_resultIsNull1 = false; /* 113 */ /* 114 */ if (!deserializetoobject_resultIsNull1) { /* 115 */ boolean deserializetoobject_isNull6 = false; /* 116 */ final scala.collection.generic.CanBuildFrom deserializetoobject_value6 = false ? null : scala.collection.immutable.List.canBuildFrom(); /* 117 */ deserializetoobject_isNull6 = deserializetoobject_value6 == null; /* 118 */ deserializetoobject_resultIsNull1 = deserializetoobject_isNull6; /* 119 */ deserializetoobject_argValue1 = deserializetoobject_value6; /* 120 */ } /* 121 */ /* 122 */ deserializetoobject_isNull = deserializetoobject_resultIsNull1; /* 123 */ if (!deserializetoobject_isNull) { /* 124 */ Object deserializetoobject_funcResult1 = null; /* 125 */ deserializetoobject_funcResult1 = deserializetoobject_value1.to(deserializetoobject_argValue1); /* 126 */ if (deserializetoobject_funcResult1 == null) { /* 127 */ deserializetoobject_isNull = true; /* 128 */ } else { /* 129 */ deserializetoobject_value = (scala.collection.immutable.List) deserializetoobject_funcResult1; /* 130 */ } /* 131 */ /* 132 */ } /* 133 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 134 */ } /* 135 */ /* 136 */ boolean mapelements_isNull = true; /* 137 */ scala.collection.immutable.List mapelements_value = null; /* 138 */ if (!false) { /* 139 */ mapelements_argValue = deserializetoobject_value; /* 140 */ /* 141 */ mapelements_isNull = false; /* 142 */ if (!mapelements_isNull) { /* 143 */ Object mapelements_funcResult = null; /* 144 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 145 */ if (mapelements_funcResult == null) { /* 146 */ mapelements_isNull = true; /* 147 */ } else { /* 148 */ mapelements_value = (scala.collection.immutable.List) mapelements_funcResult; /* 149 */ } /* 150 */ /* 151 */ } /* 152 */ mapelements_isNull = mapelements_value == null; /* 153 */ } /* 154 */ /* 155 */ if (mapelements_isNull) { /* 156 */ throw new RuntimeException(((java.lang.String) references[2])); /* 157 */ } /* 158 */ serializefromobject_argValue = mapelements_value; /* 159 */ /* 160 */ final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue); /* 161 */ serializefromobject_holder.reset(); /* 162 */ /* 163 */ // Remember the current cursor so that we can calculate how many bytes are /* 164 */ // written later. /* 165 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 166 */ /* 167 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 168 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 169 */ // grow the global buffer before writing data. 
/* 170 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 171 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 172 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 173 */ /* 174 */ } else { /* 175 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 176 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4); /* 177 */ /* 178 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 179 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 180 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index); /* 181 */ } else { /* 182 */ final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index); /* 183 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 184 */ } /* 185 */ } /* 186 */ } /* 187 */ /* 188 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 189 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 190 */ append(serializefromobject_result); /* 191 */ if (shouldStop()) return; /* 192 */ } /* 193 */ } /* 194 */ } ``` After: ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private boolean CollectObjects_loopIsNull1; /* 010 */ private int CollectObjects_loopValue0; /* 011 */ private UnsafeRow deserializetoobject_result; /* 012 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 013 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 014 */ private scala.collection.immutable.List mapelements_argValue; /* 015 */ private UnsafeRow mapelements_result; /* 016 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 018 */ private scala.collection.immutable.List serializefromobject_argValue; /* 019 */ private UnsafeRow serializefromobject_result; /* 020 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter; /* 023 */ /* 024 */ public GeneratedIterator(Object[] references) { /* 025 */ this.references = references; /* 026 */ } /* 027 */ /* 028 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 029 */ partitionIndex = index; /* 030 */ this.inputs = inputs; /* 031 */ inputadapter_input = inputs[0]; /* 032 */ /* 033 */ deserializetoobject_result = new UnsafeRow(1); /* 034 */ this.deserializetoobject_holder = new 
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32); /* 035 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 036 */ /* 037 */ mapelements_result = new UnsafeRow(1); /* 038 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32); /* 039 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 040 */ /* 041 */ serializefromobject_result = new UnsafeRow(1); /* 042 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32); /* 043 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 044 */ this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter(); /* 045 */ /* 046 */ } /* 047 */ /* 048 */ protected void processNext() throws java.io.IOException { /* 049 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 050 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 051 */ ArrayData inputadapter_value = inputadapter_row.getArray(0); /* 052 */ /* 053 */ scala.collection.immutable.List deserializetoobject_value = null; /* 054 */ /* 055 */ if (!false) { /* 056 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 057 */ scala.collection.mutable.Builder CollectObjects_builderValue2 = scala.collection.immutable.List$.MODULE$.newBuilder(); /* 058 */ CollectObjects_builderValue2.sizeHint(deserializetoobject_dataLength); /* 059 */ /* 060 */ int deserializetoobject_loopIndex = 0; /* 061 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 062 */ CollectObjects_loopValue0 = (int) (inputadapter_value.getInt(deserializetoobject_loopIndex)); /* 063 */ CollectObjects_loopIsNull1 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 064 */ /* 065 */ if (CollectObjects_loopIsNull1) { /* 066 */ throw new RuntimeException(((java.lang.String) references[0])); /* 067 */ } /* 068 */ if (false) { /* 069 */ CollectObjects_builderValue2.$plus$eq(null); /* 070 */ } else { /* 071 */ CollectObjects_builderValue2.$plus$eq(CollectObjects_loopValue0); /* 072 */ } /* 073 */ /* 074 */ deserializetoobject_loopIndex += 1; /* 075 */ } /* 076 */ /* 077 */ deserializetoobject_value = (scala.collection.immutable.List) CollectObjects_builderValue2.result(); /* 078 */ } /* 079 */ /* 080 */ boolean mapelements_isNull = true; /* 081 */ scala.collection.immutable.List mapelements_value = null; /* 082 */ if (!false) { /* 083 */ mapelements_argValue = deserializetoobject_value; /* 084 */ /* 085 */ mapelements_isNull = false; /* 086 */ if (!mapelements_isNull) { /* 087 */ Object mapelements_funcResult = null; /* 088 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 089 */ if (mapelements_funcResult == null) { /* 090 */ mapelements_isNull = true; /* 091 */ } else { /* 092 */ mapelements_value = (scala.collection.immutable.List) mapelements_funcResult; /* 093 */ } /* 094 */ /* 095 */ } /* 096 */ mapelements_isNull = mapelements_value == null; /* 097 */ } /* 098 */ /* 099 */ if (mapelements_isNull) { /* 100 */ throw new RuntimeException(((java.lang.String) references[2])); /* 101 */ } /* 102 */ serializefromobject_argValue 
= mapelements_value; /* 103 */ /* 104 */ final ArrayData serializefromobject_value = false ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue); /* 105 */ serializefromobject_holder.reset(); /* 106 */ /* 107 */ // Remember the current cursor so that we can calculate how many bytes are /* 108 */ // written later. /* 109 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 110 */ /* 111 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 112 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 113 */ // grow the global buffer before writing data. /* 114 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 115 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 116 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 117 */ /* 118 */ } else { /* 119 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 120 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4); /* 121 */ /* 122 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 123 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 124 */ serializefromobject_arrayWriter.setNullInt(serializefromobject_index); /* 125 */ } else { /* 126 */ final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index); /* 127 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 128 */ } /* 129 */ } /* 130 */ } /* 131 */ /* 132 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 133 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 134 */ append(serializefromobject_result); /* 135 */ if (shouldStop()) return; /* 136 */ } /* 137 */ } /* 138 */ } ``` Benchmark results before: ``` OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH AMD A10-4600M APU with Radeon(tm) HD Graphics collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Seq 269 / 370 0.0 269125.8 1.0X List 154 / 176 0.0 154453.5 1.7X mutable.Queue 210 / 233 0.0 209691.6 1.3X ``` Benchmark results after: ``` OpenJDK 64-Bit Server VM 1.8.0_112-b15 on Linux 4.8.13-1-ARCH AMD A10-4600M APU with Radeon(tm) HD Graphics collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Seq 255 / 316 0.0 254697.3 1.0X List 152 / 177 0.0 152410.0 1.7X mutable.Queue 213 / 235 0.0 213470.0 1.2X ``` ## How was this patch tested? ```bash ./build/mvn -DskipTests clean package && ./dev/run-tests ``` Additionally in Spark Shell: ```scala case class QueueClass(q: scala.collection.immutable.Queue[Int]) spark.createDataset(Seq(List(1,2,3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect ``` Author: Michal Senkyr Closes #16541 from michalsenkyr/dataset-seq-builder. 
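For intuition, the change above amounts to replacing a two-pass "materialize an array, then convert with `to`" construction with a single-pass `Builder`-based one. Below is a minimal plain-Scala sketch of the two shapes (illustrative only, written against the Scala 2.11/2.12 collections API; the object and method names here are assumptions for the example and are not part of this patch):

```scala
import scala.collection.mutable.WrappedArray

object SeqDeserializationSketch {
  // Old shape: materialize an Array, wrap it, then convert with `to`,
  // which walks the elements a second time for most target collections.
  def viaWrappedArray(values: Array[Int]): List[Int] =
    WrappedArray.make[Int](values).to[List]

  // New shape: ask the target collection's companion for a Builder and
  // append elements in a single pass, mirroring what the generated code
  // now does through `scala.collection.immutable.List$.MODULE$.newBuilder()`.
  def viaBuilder(values: Array[Int]): List[Int] = {
    val builder = List.newBuilder[Int]
    builder.sizeHint(values.length)
    values.foreach(builder += _)
    builder.result()
  }

  def main(args: Array[String]): Unit = {
    val input = Array(1, 2, 3)
    assert(viaWrappedArray(input) == viaBuilder(input))
  }
}
```

The builder path also lets the generated code pass a `sizeHint`, which builders that pre-allocate can exploit.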
--- .../spark/sql/catalyst/ScalaReflection.scala | 51 ++------------- .../expressions/objects/objects.scala | 64 +++++++++++++++---- .../sql/catalyst/ScalaReflectionSuite.scala | 8 --- 3 files changed, 54 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c4af284f73d1..1c7720afe1ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection { } } - val array = Invoke( - MapObjects(mapFunction, getPath, dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val wrappedArray = StaticInvoke( - scala.collection.mutable.WrappedArray.getClass, - ObjectType(classOf[Seq[_]]), - "make", - array :: Nil) - - if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) { - wrappedArray - } else { - // Convert to another type using `to` - val cls = mirror.runtimeClass(t.typeSymbol.asClass) - import scala.collection.generic.CanBuildFrom - import scala.reflect.ClassTag - - // Some canBuildFrom methods take an implicit ClassTag parameter - val cbfParams = try { - cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]]) - StaticInvoke( - ClassTag.getClass, - ObjectType(classOf[ClassTag[_]]), - "apply", - StaticInvoke( - cls, - ObjectType(classOf[Class[_]]), - "getClass" - ) :: Nil - ) :: Nil - } catch { - case _: NoSuchMethodException => Nil - } - - Invoke( - wrappedArray, - "to", - ObjectType(cls), - StaticInvoke( - cls, - ObjectType(classOf[CanBuildFrom[_, _, _]]), - "canBuildFrom", - cbfParams - ) :: Nil - ) + val cls = t.dealias.companion.decl(TermName("newBuilder")) match { + case NoSymbol => classOf[Seq[_]] + case _ => mirror.runtimeClass(t.typeSymbol.asClass) } + MapObjects(mapFunction, getPath, dataType, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 771ac28e5107..bb584f7d087e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -429,24 +430,34 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. 
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) */ def apply( function: Expression => Expression, inputData: Expression, - elementType: DataType): MapObjects = { - val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() - val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() + elementType: DataType, + customCollectionCls: Option[Class[_]] = None): MapObjects = { + val id = curId.getAndIncrement() + val loopValue = s"MapObjects_loopValue$id" + val loopIsNull = s"MapObjects_loopIsNull$id" val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData) + val builderValue = s"MapObjects_builderValue$id" + MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, + customCollectionCls, builderValue) } } /** * Applies the given expression to every element of a collection of items, returning the result - * as an ArrayType. This is similar to a typical map operation, but where the lambda function - * is expressed using catalyst expressions. + * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda + * function is expressed using catalyst expressions. + * + * The type of the result is determined as follows: + * - ArrayType - when customCollectionCls is None + * - ObjectType(collection) - when customCollectionCls contains a collection class * - * The following collection ObjectTypes are currently supported: + * The following collection ObjectTypes are currently supported on input: * Seq, Array, ArrayData, java.util.List * * @param loopValue the name of the loop variable that used when iterate the collection, and used @@ -458,13 +469,19 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. 
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) + * @param builderValue The name of the builder variable used to construct the resulting collection + * (used only when returning ObjectType) */ case class MapObjects private( loopValue: String, loopIsNull: String, loopVarDataType: DataType, lambdaFunction: Expression, - inputData: Expression) extends Expression with NonSQLExpression { + inputData: Expression, + customCollectionCls: Option[Class[_]], + builderValue: String) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -474,7 +491,8 @@ case class MapObjects private( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def dataType: DataType = - ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) + customCollectionCls.map(ObjectType.apply).getOrElse( + ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) @@ -557,15 +575,33 @@ case class MapObjects private( case _ => s"$loopIsNull = $loopValue == null;" } + val (initCollection, addElement, getResult): (String, String => String, String) = + customCollectionCls match { + case Some(cls) => + // collection + val collObjectName = s"${cls.getName}$$.MODULE$$" + val getBuilderVar = s"$collObjectName.newBuilder()" + + (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; + $builderValue.sizeHint($dataLength);""", + genValue => s"$builderValue.$$plus$$eq($genValue);", + s"(${cls.getName}) $builderValue.result();") + case None => + // array + (s"""$convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor;""", + genValue => s"$convertedArray[$loopIndex] = $genValue;", + s"new ${classOf[GenericArrayData].getName}($convertedArray);") + } + val code = s""" ${genInputData.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType - $convertedType[] $convertedArray = null; int $dataLength = $getLength; - $convertedArray = $arrayConstructor; + $initCollection int $loopIndex = 0; while ($loopIndex < $dataLength) { @@ -574,15 +610,15 @@ case class MapObjects private( ${genFunction.code} if (${genFunction.isNull}) { - $convertedArray[$loopIndex] = null; + ${addElement("null")} } else { - $convertedArray[$loopIndex] = $genFunctionValue; + ${addElement(genFunctionValue)} } $loopIndex += 1; } - ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); + ${ev.value} = $getResult } """ ev.copy(code = code, isNull = genInputData.isNull) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 650a35398f3e..70ad064f93eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite { ArrayType(IntegerType, containsNull = false)) val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) - - // Check whether conversion is skipped when using WrappedArray[_] supertype - // (would otherwise needlessly add 
overhead) - import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke - val seqDeserializer = deserializerFor[Seq[Int]] - assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject == - scala.collection.mutable.WrappedArray.getClass) - assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make") } private val dataTypeForComplexData = dataTypeFor[ComplexData] From a9abff281bcb15fdc91111121c8bcb983a9d91cb Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 28 Mar 2017 09:37:28 +0200 Subject: [PATCH 148/512] [SPARK-20119][TEST-MAVEN] Fix the test case fail in DataSourceScanExecRedactionSuite ### What changes were proposed in this pull request? Changed the pattern to match the first n characters in the location field so that the string truncation does not affect it. ### How was this patch tested? N/A Author: Xiao Li Closes #17448 from gatorsmile/fixTestCAse. --- .../spark/sql/execution/DataSourceScanExecRedactionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 986fa878ee29..05a2b2c862c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -31,7 +31,7 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { override def beforeAll(): Unit = { sparkConf.set("spark.redaction.string.regex", - "spark-[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}") + "file:/[\\w_]+") super.beforeAll() } From 91559d277f42ee83b79f5d8eb7ba037cf5c108da Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 28 Mar 2017 13:43:23 +0200 Subject: [PATCH 149/512] [SPARK-20094][SQL] Preventing push down of IN subquery to Join operator ## What changes were proposed in this pull request? TPCDS q45 fails becuase: `ReorderJoin` collects all predicates and try to put them into join condition when creating ordered join. If a predicate with an IN subquery (`ListQuery`) is in a join condition instead of a filter condition, `RewritePredicateSubquery.rewriteExistentialExpr` would fail to convert the subquery to an `ExistenceJoin`, and thus result in error. We should prevent push down of IN subquery to Join operator. ## How was this patch tested? Add a new test case in `FilterPushdownSuite`. Author: wangzhenhua Closes #17428 from wzhfy/noSubqueryInJoinCond. --- .../sql/catalyst/expressions/predicates.scala | 6 ++++++ .../optimizer/FilterPushdownSuite.scala | 20 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e5d1a1e2996c..1235204591bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,6 +90,12 @@ trait PredicateHelper { * Returns true iff `expr` could be evaluated as a condition within join. */ protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { + case l: ListQuery => + // A ListQuery defines the query which we want to search in an IN subquery expression. 
+ // Currently the only way to evaluate an IN subquery is to convert it to a + // LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule. + // It cannot be evaluated as part of a Join operator. + false case e: SubqueryExpression => // non-correlated subquery will be replaced as literal e.children.isEmpty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 6feea4060f46..d846786473eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -836,6 +836,26 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, answer) } + test("SPARK-20094: don't push predicate with IN subquery into join condition") { + val x = testRelation.subquery('x) + val z = testRelation.subquery('z) + val w = testRelation1.subquery('w) + + val queryPlan = x + .join(z) + .where(("x.b".attr === "z.b".attr) && + ("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))) + .analyze + + val expectedPlan = x + .join(z, Inner, Some("x.b".attr === "z.b".attr)) + .where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))) + .analyze + + val optimized = Optimize.execute(queryPlan) + comparePlans(optimized, expectedPlan) + } + test("Window: predicate push down -- basic") { val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) From 4fcc214d9eb5e98b2eed3e28cc23b0c511cd9007 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 28 Mar 2017 22:22:38 +0800 Subject: [PATCH 150/512] [SPARK-20124][SQL] Join reorder should keep the same order of final project attributes ## What changes were proposed in this pull request? Join reorder algorithm should keep exactly the same order of output attributes in the top project. For example, if user want to select a, b, c, after reordering, we should output a, b, c in the same order as specified by user, instead of b, a, c or other orders. ## How was this patch tested? A new test case is added in `JoinReorderSuite`. Author: wangzhenhua Closes #17453 from wzhfy/keepOrderInProject. --- .../optimizer/CostBasedJoinReorder.scala | 24 ++++++++++++------- .../catalyst/optimizer/JoinReorderSuite.scala | 13 ++++++++++ .../spark/sql/catalyst/plans/PlanTest.scala | 4 ++-- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index fc37720809ba..cbd506465ae6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -40,10 +40,10 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr val result = plan transformDown { // Start reordering with a joinable item, which is an InnerLike join with conditions. 
case j @ Join(_, _, _: InnerLike, Some(cond)) => - reorder(j, j.outputSet) + reorder(j, j.output) case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond))) if projectList.forall(_.isInstanceOf[Attribute]) => - reorder(p, p.outputSet) + reorder(p, p.output) } // After reordering is finished, convert OrderedJoin back to Join result transformDown { @@ -52,7 +52,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } } - private def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { + private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) // TODO: Compute the set of star-joins and use them in the join enumeration // algorithm to prune un-optimal plan choices. @@ -140,7 +140,7 @@ object JoinReorderDP extends PredicateHelper with Logging { conf: SQLConf, items: Seq[LogicalPlan], conditions: Set[Expression], - topOutput: AttributeSet): LogicalPlan = { + output: Seq[Attribute]): LogicalPlan = { val startTime = System.nanoTime() // Level i maintains all found plans for i + 1 items. @@ -152,9 +152,10 @@ object JoinReorderDP extends PredicateHelper with Logging { // Build plans for next levels until the last level has only one plan. This plan contains // all items that can be joined, so there's no need to continue. + val topOutputSet = AttributeSet(output) while (foundPlans.size < items.length && foundPlans.last.size > 1) { // Build plans for the next level. - foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) + foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet) } val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) @@ -163,7 +164,14 @@ object JoinReorderDP extends PredicateHelper with Logging { // The last level must have one and only one plan, because all items are joinable. assert(foundPlans.size == items.length && foundPlans.last.size == 1) - foundPlans.last.head._2.plan + foundPlans.last.head._2.plan match { + case p @ Project(projectList, j: Join) if projectList != output => + assert(topOutputSet == p.outputSet) + // Keep the same order of final output attributes. + p.copy(projectList = output) + case finalPlan => + finalPlan + } } /** Find all possible plans at the next level, based on existing levels. 
*/ @@ -254,10 +262,10 @@ object JoinReorderDP extends PredicateHelper with Logging { val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds val remainingConds = conditions -- collectedJoinConds val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput - val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) + val neededFromNewJoin = newJoin.output.filter(neededAttr.contains) val newPlan = if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { - Project(neededFromNewJoin.toSeq, newJoin) + Project(neededFromNewJoin, newJoin) } else { newJoin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 05b839b0119f..d74008c1b302 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -198,6 +198,19 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { assertEqualPlans(originalPlan, bestPlan) } + test("keep the order of attributes in the final output") { + val outputLists = Seq("t1.k-1-2", "t1.v-1-10", "t3.v-1-100").permutations + while (outputLists.hasNext) { + val expectedOrder = outputLists.next().map(nameToAttr) + val expectedPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(expectedOrder: _*) + // The plan should not change after optimization + assertEqualPlans(expectedPlan, expectedPlan) + } + } + private def assertEqualPlans( originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 2a9d0570148a..c73dfaf3f8fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -126,8 +126,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { case (j1: Join, j2: Join) => (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) - case _ if plan1.children.nonEmpty && plan2.children.nonEmpty => - (plan1.children, plan2.children).zipped.forall { case (c1, c2) => sameJoinPlan(c1, c2) } + case (p1: Project, p2: Project) => + p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) case _ => plan1 == plan2 } From f82461fc1197f6055d9cf972d82260b178e10a7c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 28 Mar 2017 23:14:31 +0800 Subject: [PATCH 151/512] [SPARK-20126][SQL] Remove HiveSessionState ## What changes were proposed in this pull request? Commit https://github.com/apache/spark/commit/ea361165e1ddce4d8aa0242ae3e878d7b39f1de2 moved most of the logic from the SessionState classes into an accompanying builder. This makes the existence of the `HiveSessionState` redundant. This PR removes the `HiveSessionState`. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17457 from hvanhovell/SPARK-20126. 
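For context, here is a minimal sketch of the builder pattern this removal relies on. All names in it (`ResourceLoader`, `SessionState`, the two builders) are simplified stand-ins invented for the sketch, not the real Spark APIs: the idea is that Hive-specific behaviour is supplied by overriding members of a builder, so one generic state class suffices and a dedicated Hive subclass of it becomes redundant.

```scala
// Simplified stand-ins, not the real Spark classes: only the shape of the design matters here.
class ResourceLoader {
  def addJar(path: String): Unit = println(s"add jar: $path")
}

class HiveResourceLoader extends ResourceLoader {
  override def addJar(path: String): Unit = {
    println(s"also register $path with the Hive client")   // hypothetical Hive-specific step
    super.addJar(path)
  }
}

// One concrete state class; it never needs to know about Hive.
class SessionState(val resourceLoader: ResourceLoader)

class BaseSessionStateBuilder {
  protected lazy val resourceLoader: ResourceLoader = new ResourceLoader
  def build(): SessionState = new SessionState(resourceLoader)
}

class HiveSessionStateBuilder extends BaseSessionStateBuilder {
  // Hive awareness lives only in the overridden builder members.
  override protected lazy val resourceLoader: ResourceLoader = new HiveResourceLoader
}

object BuilderSketch extends App {
  val state = new HiveSessionStateBuilder().build()   // a plain SessionState, Hive-aware inside
  state.resourceLoader.addJar("/tmp/example.jar")
}
```

The diff below applies the same idea to the real classes: `addJar` moves onto a `SessionResourceLoader` that the Hive builder overrides, instead of being overridden on a `HiveSessionState`.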
--- .../sql/execution/command/resources.scala | 2 +- .../spark/sql/internal/SessionState.scala | 47 +++--- .../sql/internal/sessionStateBuilders.scala | 8 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 12 +- .../server/SparkSQLOperationManager.scala | 6 +- .../execution/HiveCompatibilitySuite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 4 - .../spark/sql/hive/HiveMetastoreCatalog.scala | 9 +- .../spark/sql/hive/HiveSessionState.scala | 144 +++--------------- .../apache/spark/sql/hive/test/TestHive.scala | 23 ++- .../sql/hive/HiveMetastoreCatalogSuite.scala | 6 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 7 +- .../sql/hive/execution/HiveDDLSuite.scala | 6 +- .../apache/spark/sql/hive/parquetSuites.scala | 21 ++- 14 files changed, 104 insertions(+), 193 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala index 20b08946675d..2e859cf1ef25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala @@ -37,7 +37,7 @@ case class AddJarCommand(path: String) extends RunnableCommand { } override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.sessionState.addJar(path) + sparkSession.sessionState.resourceLoader.addJar(path) Seq(Row(0)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index b5b0bb0bfc40..c6241d923d7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -63,6 +63,7 @@ private[sql] class SessionState( val optimizer: Optimizer, val planner: SparkPlanner, val streamingQueryManager: StreamingQueryManager, + val resourceLoader: SessionResourceLoader, createQueryExecution: LogicalPlan => QueryExecution, createClone: (SparkSession, SessionState) => SessionState) { @@ -106,27 +107,6 @@ private[sql] class SessionState( def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } - - /** - * Add a jar path to [[SparkContext]] and the classloader. - * - * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs - * to add the jar to its hive client for the current session. Hence, it still needs to be in - * [[SessionState]]. - */ - def addJar(path: String): Unit = { - sparkContext.addJar(path) - val uri = new Path(path).toUri - val jarURL = if (uri.getScheme == null) { - // `path` is a local file path without a URL scheme - new File(path).toURI.toURL - } else { - // `path` is a URL with a scheme - uri.toURL - } - sharedState.jarClassLoader.addURL(jarURL) - Thread.currentThread().setContextClassLoader(sharedState.jarClassLoader) - } } private[sql] object SessionState { @@ -160,10 +140,10 @@ class SessionStateBuilder( * Session shared [[FunctionResourceLoader]]. 
*/ @InterfaceStability.Unstable -class SessionFunctionResourceLoader(session: SparkSession) extends FunctionResourceLoader { +class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoader { override def loadResource(resource: FunctionResource): Unit = { resource.resourceType match { - case JarResource => session.sessionState.addJar(resource.uri) + case JarResource => addJar(resource.uri) case FileResource => session.sparkContext.addFile(resource.uri) case ArchiveResource => throw new AnalysisException( @@ -171,4 +151,25 @@ class SessionFunctionResourceLoader(session: SparkSession) extends FunctionResou "please use --archives options while calling spark-submit.") } } + + /** + * Add a jar path to [[SparkContext]] and the classloader. + * + * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs + * to add the jar to its hive client for the current session. Hence, it still needs to be in + * [[SessionState]]. + */ + def addJar(path: String): Unit = { + session.sparkContext.addJar(path) + val uri = new Path(path).toUri + val jarURL = if (uri.getScheme == null) { + // `path` is a local file path without a URL scheme + new File(path).toURI.toURL + } else { + // `path` is a URL with a scheme + uri.toURL + } + session.sharedState.jarClassLoader.addURL(jarURL) + Thread.currentThread().setContextClassLoader(session.sharedState.jarClassLoader) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala index 6b5559adb1db..b8f645fdee85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala @@ -109,6 +109,11 @@ abstract class BaseSessionStateBuilder( */ protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + /** + * ResourceLoader that is used to load function resources and jars. + */ + protected lazy val resourceLoader: SessionResourceLoader = new SessionResourceLoader(session) + /** * Catalog for managing table and database states. If there is a pre-existing catalog, the state * of that catalog (temp tables & current database) will be copied into the new catalog. 
@@ -123,7 +128,7 @@ abstract class BaseSessionStateBuilder( conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), sqlParser, - new SessionFunctionResourceLoader(session)) + resourceLoader) parentState.foreach(_.catalog.copyStateTo(catalog)) catalog } @@ -251,6 +256,7 @@ abstract class BaseSessionStateBuilder( optimizer, planner, streamingQueryManager, + resourceLoader, createQueryExecution, createClone) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index c0b299411e94..01c4eb131a56 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils} +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ @@ -49,10 +49,12 @@ private[hive] object SparkSQLEnv extends Logging { sparkContext = sparkSession.sparkContext sqlContext = sparkSession.sqlContext - val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] - sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) - sessionState.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) - sessionState.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) + val metadataHive = sparkSession + .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + .client.newSession() + metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) + metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) + metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) sparkSession.conf.set("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 49ab66400934..a0e5012633f5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -26,7 +26,7 @@ import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.hive.HiveSessionState +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} /** @@ -49,8 +49,8 @@ private[thriftserver] class SparkSQLOperationManager() val sqlContext = sessionToContexts.get(parentSession.getSessionHandle) require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") - val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] - val runInBackground = async && sessionState.hiveThriftServerAsync + val conf = sqlContext.sessionState.conf + val runInBackground 
= async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f78660f7c14b..0a53aaca404e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -39,7 +39,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc + private val originalConvertMetastoreOrc = TestHive.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 5393c57c9a28..02a5117f005e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -48,10 +48,6 @@ class HiveContext private[hive](_sparkSession: SparkSession) new HiveContext(sparkSession.newSession()) } - protected[sql] override def sessionState: HiveSessionState = { - sparkSession.sessionState.asInstanceOf[HiveSessionState] - } - /** * Invalidate and refresh all the cached the metadata of the given table. 
For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2e060ab9f680..305bd007c93f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types._ */ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { // these are def_s and not val/lazy val since the latter would introduce circular references - private def sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] + private def sessionState = sparkSession.sessionState private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache import HiveMetastoreCatalog._ @@ -281,12 +281,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log object ParquetConversions extends Rule[LogicalPlan] { private def shouldConvertMetastoreParquet(relation: CatalogRelation): Boolean = { relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && - sessionState.convertMetastoreParquet + sessionState.conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) } private def convertToParquetRelation(relation: CatalogRelation): LogicalRelation = { val fileFormatClass = classOf[ParquetFileFormat] - val mergeSchema = sessionState.convertMetastoreParquetWithSchemaMerging + val mergeSchema = sessionState.conf.getConf( + HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) val options = Map(ParquetOptions.MERGE_SCHEMA -> mergeSchema.toString) convertToLogicalRelation(relation, options, fileFormatClass, "parquet") @@ -316,7 +317,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log object OrcConversions extends Rule[LogicalPlan] { private def shouldConvertMetastoreOrc(relation: CatalogRelation): Boolean = { relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && - sessionState.convertMetastoreOrc + sessionState.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } private def convertToOrcRelation(relation: CatalogRelation): LogicalRelation = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 49ff8478f1ae..f49e6bb41864 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -17,121 +17,24 @@ package org.apache.spark.sql.hive -import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} -import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.analysis.Analyzer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner} +import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.client.HiveClient -import 
org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionFunctionResourceLoader, SessionState, SharedState, SQLConf} -import org.apache.spark.sql.streaming.StreamingQueryManager - +import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** - * A class that holds all session-specific state in a given [[SparkSession]] backed by Hive. - * - * @param sparkContext The [[SparkContext]]. - * @param sharedState The shared state. - * @param conf SQL-specific key-value configurations. - * @param experimentalMethods The experimental methods. - * @param functionRegistry Internal catalog for managing functions registered by the user. - * @param catalog Internal catalog for managing table and database states that uses Hive client for - * interacting with the metastore. - * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. - * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. - * @param optimizer Logical query plan optimizer. - * @param planner Planner that converts optimized logical plans to physical plans and that takes - * Hive-specific strategies into account. - * @param streamingQueryManager Interface to start and stop streaming queries. - * @param createQueryExecution Function used to create QueryExecution objects. - * @param createClone Function used to create clones of the session state. - * @param metadataHive The Hive metadata client. + * Entry object for creating a Hive aware [[SessionState]]. */ -private[hive] class HiveSessionState( - sparkContext: SparkContext, - sharedState: SharedState, - conf: SQLConf, - experimentalMethods: ExperimentalMethods, - functionRegistry: FunctionRegistry, - override val catalog: HiveSessionCatalog, - sqlParser: ParserInterface, - analyzer: Analyzer, - optimizer: Optimizer, - planner: SparkPlanner, - streamingQueryManager: StreamingQueryManager, - createQueryExecution: LogicalPlan => QueryExecution, - createClone: (SparkSession, SessionState) => SessionState, - val metadataHive: HiveClient) - extends SessionState( - sparkContext, - sharedState, - conf, - experimentalMethods, - functionRegistry, - catalog, - sqlParser, - analyzer, - optimizer, - planner, - streamingQueryManager, - createQueryExecution, - createClone) { - - // ------------------------------------------------------ - // Helper methods, partially leftover from pre-2.0 days - // ------------------------------------------------------ - - override def addJar(path: String): Unit = { - metadataHive.addJar(path) - super.addJar(path) - } - - /** - * When true, enables an experimental feature where metastore tables that use the parquet SerDe - * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive - * SerDe. - */ - def convertMetastoreParquet: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) - } - - /** - * When true, also tries to merge possibly different but compatible Parquet schemas in different - * Parquet data files. - * - * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. - */ - def convertMetastoreParquetWithSchemaMerging: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) - } - - /** - * When true, enables an experimental feature where metastore tables that use the Orc SerDe - * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive - * SerDe. 
- */ - def convertMetastoreOrc: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) - } - - /** - * When true, Hive Thrift server will execute SQL queries asynchronously using a thread pool." - */ - def hiveThriftServerAsync: Boolean = { - conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) - } -} - private[hive] object HiveSessionState { /** - * Create a new [[HiveSessionState]] for the given session. + * Create a new Hive aware [[SessionState]]. for the given session. */ - def apply(session: SparkSession): HiveSessionState = { + def apply(session: SparkSession): SessionState = { new HiveSessionStateBuilder(session).build() } } @@ -147,6 +50,14 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session private def externalCatalog: HiveExternalCatalog = session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + /** + * Create a Hive aware resource loader. + */ + override protected lazy val resourceLoader: HiveSessionResourceLoader = { + val client: HiveClient = externalCatalog.client.newSession() + new HiveSessionResourceLoader(session, client) + } + /** * Create a [[HiveSessionCatalog]]. */ @@ -159,7 +70,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), sqlParser, - new SessionFunctionResourceLoader(session)) + resourceLoader) parentState.foreach(_.catalog.copyStateTo(catalog)) catalog } @@ -217,23 +128,14 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session } override protected def newBuilder: NewBuilder = new HiveSessionStateBuilder(_, _) +} - override def build(): HiveSessionState = { - val metadataHive: HiveClient = externalCatalog.client.newSession() - new HiveSessionState( - session.sparkContext, - session.sharedState, - conf, - experimentalMethods, - functionRegistry, - catalog, - sqlParser, - analyzer, - optimizer, - planner, - streamingQueryManager, - createQueryExecution, - createClone, - metadataHive) +class HiveSessionResourceLoader( + session: SparkSession, + client: HiveClient) + extends SessionResourceLoader(session) { + override def addJar(path: String): Unit = { + client.addJar(path) + super.addJar(path) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 32ca69605ef4..0bcf21992276 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -34,7 +34,6 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.catalog.ExternalCatalog import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand @@ -81,7 +80,7 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: ExternalCatalog = { + override lazy val externalCatalog: TestHiveExternalCatalog = { new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, @@ -123,8 +122,6 @@ class TestHiveContext( new TestHiveContext(sparkSession.newSession()) } - override def sessionState: HiveSessionState = sparkSession.sessionState - def 
setCacheTables(c: Boolean): Unit = { sparkSession.setCacheTables(c) } @@ -155,7 +152,7 @@ class TestHiveContext( private[hive] class TestHiveSparkSession( @transient private val sc: SparkContext, @transient private val existingSharedState: Option[TestHiveSharedState], - @transient private val parentSessionState: Option[HiveSessionState], + @transient private val parentSessionState: Option[SessionState], private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => @@ -195,10 +192,12 @@ private[hive] class TestHiveSparkSession( } @transient - override lazy val sessionState: HiveSessionState = { + override lazy val sessionState: SessionState = { new TestHiveSessionStateBuilder(this, parentSessionState).build() } + lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) } @@ -492,7 +491,7 @@ private[hive] class TestHiveSparkSession( sessionState.catalog.clearTempTables() sessionState.catalog.tableRelationCache.invalidateAll() - sessionState.metadataHive.reset() + metadataHive.reset() FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } @@ -509,14 +508,14 @@ private[hive] class TestHiveSparkSession( sessionState.conf.setConfString("fs.defaultFS", new File(".").toURI.toString) // It is important that we RESET first as broken hooks that might have been set could break // other sql exec here. - sessionState.metadataHive.runSqlHive("RESET") + metadataHive.runSqlHive("RESET") // For some reason, RESET does not reset the following variables... // https://issues.apache.org/jira/browse/HIVE-9004 - sessionState.metadataHive.runSqlHive("set hive.table.parameters.default=") - sessionState.metadataHive.runSqlHive("set datanucleus.cache.collections=true") - sessionState.metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") + metadataHive.runSqlHive("set hive.table.parameters.default=") + metadataHive.runSqlHive("set datanucleus.cache.collections=true") + metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. 
- sessionState.metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") + metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") sessionState.catalog.setCurrentDatabase("default") } catch { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 079358b29a19..d8fd68b63d1e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -115,7 +115,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) checkAnswer(table("t"), testDF) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -147,7 +147,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) checkAnswer(table("t"), testDF) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -176,7 +176,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) checkAnswer(table("t"), Row(1, "val_1")) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f02b7218d6ee..55e02acfa4ce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -379,8 +379,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv |) """.stripMargin) - val expectedPath = - sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) + val expectedPath = sessionState.catalog.defaultTablePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.delete(filesystemPath, true) @@ -486,7 +485,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sql("DROP TABLE savedJsonTable") intercept[AnalysisException] { read.json( - sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("savedJsonTable"))) + sessionState.catalog.defaultTablePath(TableIdentifier("savedJsonTable")).toString) } } @@ -756,7 +755,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv serde = None, compressed = false, properties = Map( - "path" -> sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier(tableName))) + "path" -> sessionState.catalog.defaultTablePath(TableIdentifier(tableName)).toString) ), properties = Map( DATASOURCE_PROVIDER -> "json", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 04bc79d43032..f0a995c274b6 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -128,11 +128,11 @@ class HiveDDLSuite dbPath: Option[String] = None): Boolean = { val expectedTablePath = if (dbPath.isEmpty) { - hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier) + hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier) } else { - new Path(new Path(dbPath.get), tableIdentifier.table).toString + new Path(new Path(dbPath.get), tableIdentifier.table) } - val filesystemPath = new Path(expectedTablePath) + val filesystemPath = new Path(expectedTablePath.toString) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.exists(filesystemPath) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 81af24979d82..9fc2923bb6fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.execution.HiveTableScanExec @@ -448,10 +449,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } + private def getCachedDataSourceTable(id: TableIdentifier): LogicalPlan = { + sessionState.catalog.asInstanceOf[HiveSessionCatalog].getCachedDataSourceTable(id) + } + test("Caching converted data source Parquet Relations") { def checkCached(tableIdentifier: TableIdentifier): Unit = { // Converted test_parquet should be cached. - sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match { + getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") case LogicalRelation(_: HadoopFsRelation, _, _) => // OK case other => @@ -479,14 +484,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { var tableIdentifier = TableIdentifier("test_insert_parquet", Some("default")) // First, make sure the converted test_parquet is not cached. - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Table lookup will make the table cached. table("test_insert_parquet") checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. sessionState.refreshTable("test_insert_parquet") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_insert_parquet @@ -499,7 +504,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql("select a, b from jt").collect()) // Invalidate the cache. sessionState.refreshTable("test_insert_parquet") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Create a partitioned table. 
sql( @@ -517,7 +522,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) tableIdentifier = TableIdentifier("test_parquet_partitioned_cache_test", Some("default")) - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test @@ -526,14 +531,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Make sure we can cache the partitioned table. table("test_parquet_partitioned_cache_test") @@ -549,7 +554,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin).collect()) sessionState.refreshTable("test_parquet_partitioned_cache_test") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } From 17eddb35a280e77da7520343e0bf2a86b329ed62 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 28 Mar 2017 10:41:11 -0700 Subject: [PATCH 152/512] [SPARK-19995][YARN] Register tokens to current UGI to avoid re-issuing of tokens in yarn client mode ## What changes were proposed in this pull request? In the current Spark on YARN code, we will obtain tokens from provided services, but we're not going to add these tokens to the current user's credentials. This will make all the following operations to these services still require TGT rather than delegation tokens. This is unnecessary since we already got the tokens, also this will lead to failure in user impersonation scenario, because the TGT is granted by real user, not proxy user. So here changing to put all the tokens to the current UGI, so that following operations to these services will honor tokens rather than TGT, and this will further handle the proxy user issue mentioned above. ## How was this patch tested? Local verified in secure cluster. vanzin tgravescs mridulm dongjoon-hyun please help to review, thanks a lot. Author: jerryshao Closes #17335 from jerryshao/SPARK-19995. --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ccb0f8fdbbc2..3218d221143e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -371,6 +371,9 @@ private[spark] class Client( val nearestTimeOfNextRenewal = credentialManager.obtainCredentials(hadoopConf, credentials) if (credentials != null) { + // Add credentials to current user's UGI, so that following operations don't need to use the + // Kerberos tgt to get delegations again in the client side. 
+ UserGroupInformation.getCurrentUser.addCredentials(credentials) logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) } From d4fac410e0554b7ccd44be44b7ce2fe07ed7f206 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 28 Mar 2017 11:47:43 -0700 Subject: [PATCH 153/512] [SPARK-20125][SQL] Dataset of type option of map does not work ## What changes were proposed in this pull request? When we build the deserializer expression for map type, we will use `StaticInvoke` to call `ArrayBasedMapData.toScalaMap`, and declare the return type as `scala.collection.immutable.Map`. If the map is inside an Option, we will wrap this `StaticInvoke` with `WrapOption`, which requires the input to be `scala.collect.Map`. Ideally this should be fine, as `scala.collection.immutable.Map` extends `scala.collect.Map`, but our `ObjectType` is too strict about this, this PR fixes it. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17454 from cloud-fan/map. --- .../main/scala/org/apache/spark/sql/types/ObjectType.scala | 5 +++++ .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index b18fba29af0f..2d49fe076786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -44,4 +44,9 @@ case class ObjectType(cls: Class[_]) extends DataType { def asNullable: DataType = this override def simpleString: String = cls.getName + + override def acceptsType(other: DataType): Boolean = other match { + case ObjectType(otherCls) => cls.isAssignableFrom(otherCls) + case _ => false + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6417e7a8b603..68e071a1a694 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1154,10 +1154,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(errMsg3.getMessage.startsWith("cannot have circular references in class, but got the " + "circular reference of class")) } + + test("SPARK-20125: option of map") { + val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS() + checkDataset(ds, WithMapInOption(Some(Map(1 -> 1)))) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) +case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]]) case class Generic[T](id: T, value: Double) From 92e385e0b55d70a48411e90aa0f2ed141c4d07c8 Mon Sep 17 00:00:00 2001 From: liujianhui Date: Tue, 28 Mar 2017 12:13:45 -0700 Subject: [PATCH 154/512] [SPARK-19868] conflict TasksetManager lead to spark stopped ## What changes were proposed in this pull request? We must set the taskset to zombie before the DAGScheduler handles the taskEnded event. It's possible the taskEnded event will cause the DAGScheduler to launch a new stage attempt (this happens when map output data was lost), and if this happens before the taskSet has been set to zombie, it will appear that we have conflicting task sets. Author: liujianhui Closes #17208 from liujianhuiouc/spark-19868. 
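As a rough, self-contained illustration of the ordering constraint (the names here, such as `MiniTaskSetManager`, are invented for the sketch and are not the real Spark classes): the manager's own bookkeeping, including the zombie flag, must be complete before the `taskEnded` callback fires, because that callback can immediately resubmit the stage.

```scala
// Hedged sketch only: invented names, no real Spark APIs.
object ZombieOrderingSketch extends App {
  class MiniTaskSetManager(numTasks: Int)(dagTaskEnded: MiniTaskSetManager => Unit) {
    private var tasksSuccessful = 0
    @volatile var isZombie = false

    def handleSuccessfulTask(): Unit = {
      tasksSuccessful += 1
      if (tasksSuccessful == numTasks) {
        isZombie = true          // 1. finish the manager's own state changes first
      }
      dagTaskEnded(this)         // 2. only then notify; a stage resubmitted in response
                                 //    to this event already sees a zombie task set
    }
  }

  // Like the regression test in the diff below, the "DAG scheduler" checks the flag
  // at the moment it is notified.
  var zombieWhenNotified = false
  val manager = new MiniTaskSetManager(numTasks = 1)(m => zombieWhenNotified = m.isZombie)
  manager.handleSuccessfulTask()
  assert(zombieWhenNotified, "the DAG scheduler must already see a zombie task set")
}
```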
--- .../spark/scheduler/TaskSetManager.scala | 15 ++++++----- .../spark/scheduler/TaskSetManagerSuite.scala | 27 ++++++++++++++++++- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a177aab5f95d..a41b059fa7de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -713,13 +713,7 @@ private[spark] class TaskSetManager( successfulTaskDurations.insert(info.duration) } removeRunningTask(tid) - // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the - // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not - // "deserialize" the value when holding a lock to avoid blocking other threads. So we call - // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. - // Note: "result.value()" only deserializes the value when it's called at the first time, so - // here "result.value()" just returns the value and won't block other threads. - sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) + // Kill any other attempts for the same task (since those are unnecessary now that one // attempt completed successfully). for (attemptInfo <- taskAttempts(index) if attemptInfo.running) { @@ -746,6 +740,13 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. + sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) maybeFinishTaskSet() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 132caef0978f..9ca6b8b0fe63 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -22,8 +22,10 @@ import java.util.{Properties, Random} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.{anyInt, anyString} +import org.mockito.Matchers.{any, anyInt, anyString} import org.mockito.Mockito.{mock, never, spy, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.internal.config @@ -1056,6 +1058,29 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.isZombie) } + + test("SPARK-19868: DagScheduler only notified of taskEnd when state is ready") { + // dagScheduler.taskEnded() is async, so it may *seem* ok to call it before we've set all + // appropriate state, eg. isZombie. However, this sets up a race that could go the wrong way. 
+ // This is a super-focused regression test which checks the zombie state as soon as + // dagScheduler.taskEnded() is called, to ensure we haven't introduced a race. + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler = mockDAGScheduler + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).then(new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie === true) + } + }) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption.isDefined) + // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon + manager.handleSuccessfulTask(0, createTaskResult(0)) + } + test("SPARK-17894: Verify TaskSetManagers for different stage attempts have unique names") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) From 7d432af8f3c47973550ea253dae0c23cd2961bde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?= Date: Tue, 28 Mar 2017 16:14:01 -0700 Subject: [PATCH 155/512] [SPARK-20043][ML] DecisionTreeModel: ImpurityCalculator builder fails for uppercase impurity type Gini MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix bug: DecisionTreeModel can't recongnize Impurity "Gini" when loading TODO: + [x] add unit test + [x] fix the bug Author: 颜发才(Yan Facai) Closes #17407 from facaiy/BUG/decision_tree_loader_failer_with_Gini_impurity. --- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index a5bdc2c6d2c9..98a3021461eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -184,7 +184,7 @@ private[spark] object ImpurityCalculator { * the given stats. 
*/ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { - impurity match { + impurity.toLowerCase match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 10de50306a5c..964fcfbdd87a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -385,6 +385,20 @@ class DecisionTreeClassifierSuite testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("SPARK-20043: " + + "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") { + val rdd = TreeTests.getTreeReadWriteData(sc) + val data: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + val model = dt.fit(data) + + testDefaultReadWrite(model) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { From a5c87707eaec5cacdfb703eb396dfc264bc54cda Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 28 Mar 2017 19:19:16 -0700 Subject: [PATCH 156/512] [SPARK-20040][ML][PYTHON] pyspark wrapper for ChiSquareTest ## What changes were proposed in this pull request? A pyspark wrapper for spark.ml.stat.ChiSquareTest ## How was this patch tested? unit tests doctests Author: Bago Amirbekian Closes #17421 from MrBago/chiSquareTestWrapper. --- dev/sparktestsupport/modules.py | 1 + .../apache/spark/ml/stat/ChiSquareTest.scala | 6 +- python/docs/pyspark.ml.rst | 8 ++ python/pyspark/ml/stat.py | 93 +++++++++++++++++++ python/pyspark/ml/tests.py | 31 +++++-- 5 files changed, 127 insertions(+), 12 deletions(-) create mode 100644 python/pyspark/ml/stat.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index eaf1f3a1db2f..246f5188a518 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -431,6 +431,7 @@ def __hash__(self): "pyspark.ml.linalg.__init__", "pyspark.ml.recommendation", "pyspark.ml.regression", + "pyspark.ml.stat", "pyspark.ml.tuning", "pyspark.ml.tests", ], diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala index 21eba9a49809..5b38ca73e801 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala @@ -46,9 +46,9 @@ object ChiSquareTest { statistics: Vector) /** - * Conduct Pearson's independence test for every feature against the label across the input RDD. - * For each feature, the (feature, label) pairs are converted into a contingency matrix for which - * the Chi-squared statistic is computed. All label and feature values must be categorical. + * Conduct Pearson's independence test for every feature against the label. For each feature, the + * (feature, label) pairs are converted into a contingency matrix for which the Chi-squared + * statistic is computed. All label and feature values must be categorical. 
* * The null hypothesis is that the occurrence of the outcomes is statistically independent. * diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index a68183445d78..930646de9cd8 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -65,6 +65,14 @@ pyspark.ml.regression module :undoc-members: :inherited-members: +pyspark.ml.stat module +---------------------- + +.. automodule:: pyspark.ml.stat + :members: + :undoc-members: + :inherited-members: + pyspark.ml.tuning module ------------------------ diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py new file mode 100644 index 000000000000..db043ff68fec --- /dev/null +++ b/python/pyspark/ml/stat.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import since, SparkContext +from pyspark.ml.common import _java2py, _py2java +from pyspark.ml.wrapper import _jvm + + +class ChiSquareTest(object): + """ + .. note:: Experimental + + Conduct Pearson's independence test for every feature against the label. For each feature, + the (feature, label) pairs are converted into a contingency matrix for which the Chi-squared + statistic is computed. All label and feature values must be categorical. + + The null hypothesis is that the occurrence of the outcomes is statistically independent. + + :param dataset: + DataFrame of categorical labels and categorical features. + Real-valued features will be treated as categorical for each distinct value. + :param featuresCol: + Name of features column in dataset, of type `Vector` (`VectorUDT`). + :param labelCol: + Name of label column in dataset, of any numerical type. + :return: + DataFrame containing the test result for every feature against the label. + This DataFrame will contain a single Row with the following fields: + - `pValues: Vector` + - `degreesOfFreedom: Array[Int]` + - `statistics: Vector` + Each of these fields has one value per feature. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import ChiSquareTest + >>> dataset = [[0, Vectors.dense([0, 0, 1])], + ... [0, Vectors.dense([1, 0, 1])], + ... [1, Vectors.dense([2, 1, 1])], + ... [1, Vectors.dense([3, 1, 1])]] + >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') + >>> chiSqResult.select("degreesOfFreedom").collect()[0] + Row(degreesOfFreedom=[3, 1, 0]) + + .. versionadded:: 2.2.0 + + """ + @staticmethod + @since("2.2.0") + def test(dataset, featuresCol, labelCol): + """ + Perform a Pearson's independence test using dataset. 
+ """ + sc = SparkContext._active_spark_context + javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest + args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)] + return _java2py(sc, javaTestObj.test(*args)) + + +if __name__ == "__main__": + import doctest + import pyspark.ml.stat + from pyspark.sql import SparkSession + + globs = pyspark.ml.stat.__dict__.copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + spark = SparkSession.builder \ + .master("local[2]") \ + .appName("ml.stat tests") \ + .getOrCreate() + sc = spark.sparkContext + globs['sc'] = sc + globs['spark'] = spark + + failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + spark.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 527db9b66793..571ac4bc1c36 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -41,9 +41,7 @@ import tempfile import array as pyarray import numpy as np -from numpy import ( - abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros) -from numpy import sum as array_sum +from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros import inspect from pyspark import keyword_only, SparkContext @@ -54,20 +52,19 @@ from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel -from pyspark.ml.linalg import ( - DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, - SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector) +from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ + SparseMatrix, SparseVector, Vector, VectorUDT, Vectors from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import ( - DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression) +from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \ + LinearRegression +from pyspark.ml.stat import ChiSquareTest from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams, JavaWrapper from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand -from pyspark.sql.utils import IllegalArgumentException from pyspark.storagelevel import * from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -1741,6 +1738,22 @@ def test_new_java_array(self): self.assertEqual(_java2py(self.sc, java_array), []) +class ChiSquareTestTests(SparkSessionTestCase): + + def test_chisquaretest(self): + data = [[0, Vectors.dense([0, 1, 2])], + [1, Vectors.dense([1, 1, 1])], + [2, Vectors.dense([2, 1, 0])]] + df = self.spark.createDataFrame(data, ['label', 'feat']) + res = ChiSquareTest.test(df, 'feat', 'label') + # This line is hitting the collect bug described in #17218, commented for now. 
+ # pValues = res.select("degreesOfFreedom").collect()) + self.assertIsInstance(res, DataFrame) + fieldNames = set(field.name for field in res.schema.fields) + expectedFields = ["pValues", "degreesOfFreedom", "statistics"] + self.assertTrue(all(field in fieldNames for field in expectedFields)) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: From 9712bd3954c029de5c828f27b57d46e4a6325a38 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Mar 2017 00:02:15 -0700 Subject: [PATCH 157/512] [SPARK-20134][SQL] SQLMetrics.postDriverMetricUpdates to simplify driver side metric updates ## What changes were proposed in this pull request? It is not super intuitive how to update SQLMetric on the driver side. This patch introduces a new SQLMetrics.postDriverMetricUpdates function to do that, and adds documentation to make it more obvious. ## How was this patch tested? Updated a test case to use this method. Author: Reynold Xin Closes #17464 from rxin/SPARK-20134. --- .../execution/basicPhysicalOperators.scala | 8 +------- .../exchange/BroadcastExchangeExec.scala | 8 +------- .../sql/execution/metric/SQLMetrics.scala | 20 +++++++++++++++++++ .../spark/sql/execution/ui/SQLListener.scala | 7 +++++++ .../sql/execution/ui/SQLListenerSuite.scala | 8 +++++--- 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d876688a8aab..66a8e044ab87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -628,13 +628,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum longMetric("dataSize") += dataSize - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. - if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) rows } }(SubqueryExec.executionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 7be5d31d4a76..efcaca9338ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -97,13 +97,7 @@ case class BroadcastExchangeExec( val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. 
- if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { case oe: OutOfMemoryError => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index dbc27d8b237f..ef982a4ebd10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -22,9 +22,15 @@ import java.util.Locale import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} +/** + * A metric used in a SQL query plan. This is implemented as an [[AccumulatorV2]]. Updates on + * the executor side are automatically propagated and shown in the SQL UI through metrics. Updates + * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. + */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will @@ -126,4 +132,18 @@ object SQLMetrics { s"\n$sum ($min, $med, $max)" } } + + /** + * Updates metrics based on the driver side value. This is useful for certain metrics that + * are only updated on the driver, e.g. subquery execution time, or number of files. + */ + def postDriverMetricUpdates( + sc: SparkContext, executionId: String, metrics: Seq[SQLMetric]): Unit = { + // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` + // directly without setting an execution id. We should be tolerant to it. + if (executionId != null) { + sc.listenerBus.post( + SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 12d3bc9281f3..b4a91230a001 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -47,6 +47,13 @@ case class SparkListenerSQLExecutionStart( case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent +/** + * A message used to update SQL metric value for driver-side updates (which doesn't get reflected + * automatically). + * + * @param executionId The execution id for a query, so we can find the query plan. + * @param accumUpdates Map from accumulator id to the metric value (metrics are always 64-bit ints). 
+ */ @DeveloperApi case class SparkListenerDriverAccumUpdates( executionId: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index e41c00ecec27..e6cd41e4facf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -477,9 +477,11 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue - sc.listenerBus.post(SparkListenerDriverAccumUpdates( - sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY).toLong, - metrics.values.map(m => m.id -> m.value).toSeq)) + + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + metrics.values.toSeq) sc.emptyRDD } } From b56ad2b1ec19fd60fa9d4926d12244fd3f56aca4 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 29 Mar 2017 20:27:41 +0800 Subject: [PATCH 158/512] [SPARK-19556][CORE] Do not encrypt block manager data in memory. This change modifies the way block data is encrypted to make the more common cases faster, while penalizing an edge case. As a side effect of the change, all data that goes through the block manager is now encrypted only when needed, including the previous path (broadcast variables) where that did not happen. The way the change works is by not encrypting data that is stored in memory; so if a serialized block is in memory, it will only be encrypted once it is evicted to disk. The penalty comes when transferring that encrypted data from disk. If the data ends up in memory again, it is as efficient as before; but if the evicted block needs to be transferred directly to a remote executor, then there's now a performance penalty, since the code now uses a custom FileRegion implementation to decrypt the data before transferring. This also means that block data transferred between executors now is not encrypted (and thus relies on the network library encryption support for secrecy). Shuffle blocks are still transferred in encrypted form, since they're handled in a slightly different way by the code. This also keeps compatibility with existing external shuffle services, which transfer encrypted shuffle blocks, and avoids having to make the external service aware of encryption at all. The serialization and deserialization APIs in the SerializerManager now do not do encryption automatically; callers need to explicitly wrap their streams with an appropriate crypto stream before using those. As a result of these changes, some of the workarounds added in SPARK-19520 are removed here. Testing: a new trait ("EncryptionFunSuite") was added that provides an easy way to run a test twice, with encryption on and off; broadcast, block manager and caching tests were modified to use this new trait so that the existing tests exercise both encrypted and non-encrypted paths. I also ran some applications with encryption turned on to verify that they still work, including streaming tests that failed without the fix for SPARK-19520. Author: Marcelo Vanzin Closes #17295 from vanzin/SPARK-19556. 
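Among other things, the diff below adds channel-based helpers to CryptoStreamUtils (`createWritableChannel` / `createReadableChannel`) and a `JavaUtils.readFully(channel, buffer)` utility. As a rough usage sketch only (not taken from the patch): it assumes default crypto settings on a plain SparkConf and generates a throwaway AES key locally, whereas in Spark the key comes from the SecurityManager when `spark.io.encryption.enabled` is set. A block written through the writable channel can then be read back through the readable one:

```scala
import java.io.{File, FileInputStream, FileOutputStream}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8
import javax.crypto.KeyGenerator

import org.apache.spark.SparkConf
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.security.CryptoStreamUtils

object CryptoChannelSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Throwaway 128-bit AES key for illustration only; Spark derives the real
    // key via the SecurityManager when spark.io.encryption.enabled is true.
    val keyGen = KeyGenerator.getInstance("AES")
    keyGen.init(128)
    val key = keyGen.generateKey().getEncoded

    val data = "some block data".getBytes(UTF_8)
    val file = File.createTempFile("block", ".enc")

    // Encrypt while writing: the helper prepends the IV and wraps the channel.
    val out = CryptoStreamUtils.createWritableChannel(
      new FileOutputStream(file).getChannel, conf, key)
    out.write(ByteBuffer.wrap(data))
    out.close()

    // Decrypt while reading: the helper consumes the IV and wraps the channel.
    val in = CryptoStreamUtils.createReadableChannel(
      new FileInputStream(file).getChannel, conf, key)
    val buf = ByteBuffer.allocate(data.length)
    JavaUtils.readFully(in, buf) // readFully(channel, buffer) is added in this same patch
    in.close()
    buf.flip()

    println(new String(buf.array(), 0, buf.limit(), UTF_8)) // "some block data"
  }
}
```

This is the same pattern the new DiskStore code in the diff follows: `openForWrite` wraps the file channel with `createWritableChannel`, and `EncryptedBlockData.open` wraps it with `createReadableChannel` for reads.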
--- .../apache/spark/network/util/JavaUtils.java | 15 ++ .../spark/broadcast/TorrentBroadcast.scala | 35 +-- .../spark/security/CryptoStreamUtils.scala | 87 ++++++- .../spark/serializer/SerializerManager.scala | 24 +- .../apache/spark/storage/BlockManager.scala | 172 ++++++++----- .../storage/BlockManagerManagedBuffer.scala | 33 ++- .../org/apache/spark/storage/DiskStore.scala | 236 +++++++++++++++--- .../apache/spark/storage/StorageUtils.scala | 32 --- .../spark/storage/memory/MemoryStore.scala | 2 +- .../spark/util/io/ChunkedByteBuffer.scala | 14 -- .../org/apache/spark/DistributedSuite.scala | 12 +- .../spark/broadcast/BroadcastSuite.scala | 16 +- .../security/CryptoStreamUtilsSuite.scala | 46 +++- .../spark/security/EncryptionFunSuite.scala | 39 +++ .../spark/storage/BlockManagerSuite.scala | 77 +++--- .../apache/spark/storage/DiskStoreSuite.scala | 115 ++++++++- .../rdd/WriteAheadLogBackedBlockRDD.scala | 6 +- .../receiver/ReceivedBlockHandler.scala | 11 +- .../streaming/ReceivedBlockHandlerSuite.scala | 5 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 3 +- 20 files changed, 710 insertions(+), 270 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index f3eaf22c0166..51d7fda0cb26 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -18,9 +18,11 @@ package org.apache.spark.network.util; import java.io.Closeable; +import java.io.EOFException; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; @@ -344,4 +346,17 @@ public static byte[] bufferToArray(ByteBuffer buffer) { } } + /** + * Fills a buffer with data read from the channel. + */ + public static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException { + int expected = dst.remaining(); + while (dst.hasRemaining()) { + if (channel.read(dst) < 0) { + throw new EOFException(String.format("Not enough bytes in channel (expected %d).", + expected)); + } + } + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 22d01c47e645..039df75ce74f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} +import org.apache.spark.storage._ import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -141,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } /** Fetch torrent blocks from the driver and/or other executors. */ - private def readBlocks(): Array[ChunkedByteBuffer] = { + private def readBlocks(): Array[BlockData] = { // Fetch chunks of data. 
Note that all these chunks are stored in the BlockManager and reported // to the driver, so other executors can pull these chunks from this executor as well. - val blocks = new Array[ChunkedByteBuffer](numBlocks) + val blocks = new Array[BlockData](numBlocks) val bm = SparkEnv.get.blockManager for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { @@ -173,7 +173,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) throw new SparkException( s"Failed to store $pieceId of $broadcastId in local BlockManager") } - blocks(pid) = b + blocks(pid) = new ByteBufferBlockData(b, true) case None => throw new SparkException(s"Failed to get $pieceId of $broadcastId") } @@ -219,18 +219,22 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) case None => logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks().flatMap(_.getChunks()) + val blocks = readBlocks() logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks, SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + obj + } finally { + blocks.foreach(_.dispose()) } - obj } } } @@ -277,12 +281,11 @@ private object TorrentBroadcast extends Logging { } def unBlockifyObject[T: ClassTag]( - blocks: Array[ByteBuffer], + blocks: Array[InputStream], serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") - val is = new SequenceInputStream( - blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) + val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index cdd3b8d8512b..78dabb42ac9d 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,20 +16,23 @@ */ package org.apache.spark.security -import java.io.{InputStream, OutputStream} +import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} import scala.collection.JavaConverters._ +import com.google.common.io.ByteStreams import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ import 
org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.network.util.CryptoUtils +import org.apache.spark.network.util.{CryptoUtils, JavaUtils} /** * A util class for manipulating IO encryption and decryption streams. @@ -48,12 +51,27 @@ private[spark] object CryptoStreamUtils extends Logging { os: OutputStream, sparkConf: SparkConf, key: Array[Byte]): OutputStream = { - val properties = toCryptoConf(sparkConf) - val iv = createInitializationVector(properties) + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) os.write(iv) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoOutputStream(transformationStr, properties, os, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Wrap a `WritableByteChannel` for encryption. + */ + def createWritableChannel( + channel: WritableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): WritableByteChannel = { + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) + val helper = new CryptoHelperChannel(channel) + + helper.write(ByteBuffer.wrap(iv)) + new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, + new IvParameterSpec(iv)) } /** @@ -63,12 +81,27 @@ private[spark] object CryptoStreamUtils extends Logging { is: InputStream, sparkConf: SparkConf, key: Array[Byte]): InputStream = { - val properties = toCryptoConf(sparkConf) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) - is.read(iv, 0, iv.length) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoInputStream(transformationStr, properties, is, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + ByteStreams.readFully(is, iv) + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Wrap a `ReadableByteChannel` for decryption. + */ + def createReadableChannel( + channel: ReadableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): ReadableByteChannel = { + val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + val buf = ByteBuffer.wrap(iv) + JavaUtils.readFully(channel, buf) + + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)) } def toCryptoConf(conf: SparkConf): Properties = { @@ -102,4 +135,34 @@ private[spark] object CryptoStreamUtils extends Logging { } iv } + + /** + * This class is a workaround for CRYPTO-125, that forces all bytes to be written to the + * underlying channel. Since the callers of this API are using blocking I/O, there are no + * concerns with regards to CPU usage here. 
+ */ + private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel { + + override def write(src: ByteBuffer): Int = { + val count = src.remaining() + while (src.hasRemaining()) { + sink.write(src) + } + count + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + + } + + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { + + val keySpec = new SecretKeySpec(key, "AES") + val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) + val conf = toCryptoConf(sparkConf) + + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 96b288b9cfb8..bb7ed8709ba8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -148,14 +148,14 @@ private[spark] class SerializerManager( /** * Wrap an output stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } @@ -167,30 +167,26 @@ private[spark] class SerializerManager( val byteStream = new BufferedOutputStream(outputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance() - ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag]( blockId: BlockId, - values: Iterator[T], - allowEncryption: Boolean = true): ChunkedByteBuffer = { - dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]], - allowEncryption = allowEncryption) + values: Iterator[T]): ChunkedByteBuffer = { + dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) } /** Serializes into a chunked byte buffer. 
*/ def dataSerializeWithExplicitClassTag( blockId: BlockId, values: Iterator[_], - classTag: ClassTag[_], - allowEncryption: Boolean = true): ChunkedByteBuffer = { + classTag: ClassTag[_]): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(classTag, autoPick).newInstance() - val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream - ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -200,15 +196,13 @@ private[spark] class SerializerManager( */ def dataDeserializeStream[T]( blockId: BlockId, - inputStream: InputStream, - maybeEncrypted: Boolean = true) + inputStream: InputStream) (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] - val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream getSerializer(classTag, autoPick) .newInstance() - .deserializeStream(wrapForCompression(blockId, decrypted)) + .deserializeStream(wrapForCompression(blockId, inputStream)) .asIterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 991346a40af4..fcda9fa65303 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer +import java.nio.channels.Channels import scala.collection.mutable import scala.collection.mutable.HashMap @@ -35,7 +36,7 @@ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -55,6 +56,55 @@ private[spark] class BlockResult( val readMethod: DataReadMethod.Value, val bytes: Long) +/** + * Abstracts away how blocks are stored and provides different ways to read the underlying block + * data. Callers should call [[dispose()]] when they're done with the block. + */ +private[spark] trait BlockData { + + def toInputStream(): InputStream + + /** + * Returns a Netty-friendly wrapper for the block's data. 
+ * + * @see [[ManagedBuffer#convertToNetty()]] + */ + def toNetty(): Object + + def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer + + def toByteBuffer(): ByteBuffer + + def size: Long + + def dispose(): Unit + +} + +private[spark] class ByteBufferBlockData( + val buffer: ChunkedByteBuffer, + val shouldDispose: Boolean) extends BlockData { + + override def toInputStream(): InputStream = buffer.toInputStream(dispose = false) + + override def toNetty(): Object = buffer.toNetty + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + buffer.copy(allocator) + } + + override def toByteBuffer(): ByteBuffer = buffer.toByteBuffer + + override def size: Long = buffer.size + + override def dispose(): Unit = { + if (shouldDispose) { + buffer.dispose() + } + } + +} + /** * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). @@ -94,7 +144,7 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private[spark] val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) - private[spark] val diskStore = new DiskStore(conf, diskBlockManager) + private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager) memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxMemory` may actually vary over time. @@ -304,7 +354,8 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { getLocalBytes(blockId) match { - case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) + case Some(blockData) => + new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true) case None => // If this block manager receives a request for a block that it doesn't have then it's // likely that the master has outdated block statuses for this block. Therefore, we send @@ -463,21 +514,22 @@ private[spark] class BlockManager( val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) } else if (level.useDisk && diskStore.contains(blockId)) { + val diskData = diskStore.getBytes(blockId) val iterToReturn: Iterator[Any] = { - val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { val diskValues = serializerManager.dataDeserializeStream( blockId, - diskBytes.toInputStream(dispose = true))(info.classTag) + diskData.toInputStream())(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { - val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - .map {_.toInputStream(dispose = false)} - .getOrElse { diskBytes.toInputStream(dispose = true) } + val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map { _.toInputStream(dispose = false) } + .getOrElse { diskData.toInputStream() } serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } - val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, + releaseLockAndDispose(blockId, diskData)) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { handleLocalReadFailure(blockId) @@ -488,7 +540,7 @@ private[spark] class BlockManager( /** * Get block from the local block manager as serialized bytes. 
*/ - def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + def getLocalBytes(blockId: BlockId): Option[BlockData] = { logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work @@ -496,9 +548,9 @@ private[spark] class BlockManager( val shuffleBlockResolver = shuffleManager.shuffleBlockResolver // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. - Option( - new ChunkedByteBuffer( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())) + val buf = new ChunkedByteBuffer( + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + Some(new ByteBufferBlockData(buf, true)) } else { blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) } } @@ -510,7 +562,7 @@ private[spark] class BlockManager( * Must be called while holding a read lock on the block. * Releases the read lock upon exception; keeps the read lock upon successful return. */ - private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = { + private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): BlockData = { val level = info.level logDebug(s"Level for block $blockId is $level") // In order, try to read the serialized bytes from memory, then from disk, then fall back to @@ -525,17 +577,19 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - serializerManager.dataSerializeWithExplicitClassTag( - blockId, memoryStore.getValues(blockId).get, info.classTag) + new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag( + blockId, memoryStore.getValues(blockId).get, info.classTag), true) } else { handleLocalReadFailure(blockId) } } else { // storage level is serialized if (level.useMemory && memoryStore.contains(blockId)) { - memoryStore.getBytes(blockId).get + new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false) } else if (level.useDisk && diskStore.contains(blockId)) { - val diskBytes = diskStore.getBytes(blockId) - maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes) + val diskData = diskStore.getBytes(blockId) + maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map(new ByteBufferBlockData(_, false)) + .getOrElse(diskData) } else { handleLocalReadFailure(blockId) } @@ -761,43 +815,15 @@ private[spark] class BlockManager( * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing * so may corrupt or change the data stored by the `BlockManager`. * - * @param encrypt If true, asks the block manager to encrypt the data block before storing, - * when I/O encryption is enabled. This is required for blocks that have been - * read from unencrypted sources, since all the BlockManager read APIs - * automatically do decryption. * @return true if the block was stored or false if an error occurred. 
*/ def putBytes[T: ClassTag]( blockId: BlockId, bytes: ChunkedByteBuffer, level: StorageLevel, - tellMaster: Boolean = true, - encrypt: Boolean = false): Boolean = { + tellMaster: Boolean = true): Boolean = { require(bytes != null, "Bytes is null") - - val bytesToStore = - if (encrypt && securityManager.ioEncryptionKey.isDefined) { - try { - val data = bytes.toByteBuffer - val in = new ByteBufferInputStream(data) - val byteBufOut = new ByteBufferOutputStream(data.remaining()) - val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf, - securityManager.ioEncryptionKey.get) - try { - ByteStreams.copy(in, out) - } finally { - in.close() - out.close() - } - new ChunkedByteBuffer(byteBufOut.toByteBuffer) - } finally { - bytes.dispose() - } - } else { - bytes - } - - doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster) + doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster) } /** @@ -828,8 +854,9 @@ private[spark] class BlockManager( val replicationFuture = if (level.replication > 1) { Future { // This is a blocking action and should run in futureExecutionContext which is a cached - // thread pool - replicate(blockId, bytes, level, classTag) + // thread pool. The ByteBufferBlockData wrapper is not disposed of to avoid releasing + // buffers that are owned by the caller. + replicate(blockId, new ByteBufferBlockData(bytes, false), level, classTag) }(futureExecutionContext) } else { null @@ -1008,8 +1035,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iter)(classTag) } size = diskStore.getSize(blockId) } else { @@ -1024,8 +1052,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - partiallySerializedValues.finishWritingToStream(fileOutputStream) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + partiallySerializedValues.finishWritingToStream(out) } size = diskStore.getSize(blockId) } else { @@ -1035,8 +1064,9 @@ private[spark] class BlockManager( } } else if (level.useDisk) { - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iterator())(classTag) } size = diskStore.getSize(blockId) } @@ -1065,7 +1095,7 @@ private[spark] class BlockManager( try { replicate(blockId, bytesToReplicate, level, remoteClassTag) } finally { - bytesToReplicate.unmap() + bytesToReplicate.dispose() } logDebug("Put block %s remotely took %s" .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) @@ -1089,29 +1119,29 @@ private[spark] class BlockManager( blockInfo: BlockInfo, blockId: BlockId, level: StorageLevel, - diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = { + diskData: BlockData): Option[ChunkedByteBuffer] = { require(!level.deserialized) if (level.useMemory) { // Synchronize on blockInfo to guard against a race 
condition where two readers both try to // put values read from disk into the MemoryStore. blockInfo.synchronized { if (memoryStore.contains(blockId)) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { val allocator = level.memoryMode match { case MemoryMode.ON_HEAP => ByteBuffer.allocate _ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ } - val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => { + val putSucceeded = memoryStore.putBytes(blockId, diskData.size, level.memoryMode, () => { // https://issues.apache.org/jira/browse/SPARK-6076 // If the file size is bigger than the free memory, OOM will happen. So if we // cannot put it into MemoryStore, copyForMemory should not be created. That's why // this action is put into a `() => ChunkedByteBuffer` and created lazily. - diskBytes.copy(allocator) + diskData.toChunkedByteBuffer(allocator) }) if (putSucceeded) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { None @@ -1203,7 +1233,7 @@ private[spark] class BlockManager( replicate(blockId, data, storageLevel, info.classTag, existingReplicas) } finally { logDebug(s"Releasing lock for $blockId") - releaseLock(blockId) + releaseLockAndDispose(blockId, data) } } } @@ -1214,7 +1244,7 @@ private[spark] class BlockManager( */ private def replicate( blockId: BlockId, - data: ChunkedByteBuffer, + data: BlockData, level: StorageLevel, classTag: ClassTag[_], existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { @@ -1256,7 +1286,7 @@ private[spark] class BlockManager( peer.port, peer.executorId, blockId, - new NettyManagedBuffer(data.toNetty), + new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false), tLevel, classTag) logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + @@ -1339,10 +1369,11 @@ private[spark] class BlockManager( logInfo(s"Writing block $blockId to disk") data() match { case Left(elements) => - diskStore.put(blockId) { fileOutputStream => + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream( blockId, - fileOutputStream, + out, elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) } case Right(bytes) => @@ -1434,6 +1465,11 @@ private[spark] class BlockManager( } } + def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = { + blockInfoManager.unlock(blockId) + data.dispose() + } + def stop(): Unit = { blockTransferService.close() if (shuffleClient ne blockTransferService) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index f66f94279855..1ea0d378cbe8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -17,31 +17,52 @@ package org.apache.spark.storage -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import java.io.InputStream +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.util.io.ChunkedByteBuffer /** - * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]] + * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]] * so that the corresponding block's read lock can be released once this 
buffer's references * are released. * + * If `dispose` is set to true, the [[BlockData]]will be disposed when the buffer's reference + * count drops to zero. + * * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks * to the network layer's notion of retain / release counts. */ private[storage] class BlockManagerManagedBuffer( blockInfoManager: BlockInfoManager, blockId: BlockId, - chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) { + data: BlockData, + dispose: Boolean) extends ManagedBuffer { + + private val refCount = new AtomicInteger(1) + + override def size(): Long = data.size + + override def nioByteBuffer(): ByteBuffer = data.toByteBuffer() + + override def createInputStream(): InputStream = data.toInputStream() + + override def convertToNetty(): Object = data.toNetty() override def retain(): ManagedBuffer = { - super.retain() + refCount.incrementAndGet() val locked = blockInfoManager.lockForReading(blockId, blocking = false) assert(locked.isDefined) this - } + } override def release(): ManagedBuffer = { blockInfoManager.unlock(blockId) - super.release() + if (refCount.decrementAndGet() == 0 && dispose) { + data.dispose() + } + this } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index ca23e2391ed0..c6656341fcd1 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,48 +17,67 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, IOException, RandomAccessFile} +import java.io._ import java.nio.ByteBuffer +import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel} import java.nio.channels.FileChannel.MapMode +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.concurrent.ConcurrentHashMap -import com.google.common.io.Closeables +import scala.collection.mutable.ListBuffer -import org.apache.spark.SparkConf +import com.google.common.io.{ByteStreams, Closeables, Files} +import io.netty.channel.FileRegion +import io.netty.util.AbstractReferenceCounted + +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ChunkedByteBuffer /** * Stores BlockManager blocks on disk. */ -private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging { +private[spark] class DiskStore( + conf: SparkConf, + diskManager: DiskBlockManager, + securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") + private val blockSizes = new ConcurrentHashMap[String, Long]() - def getSize(blockId: BlockId): Long = { - diskManager.getFile(blockId.name).length - } + def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) /** * Invokes the provided callback function to write the specific block. * * @throws IllegalStateException if the block already exists in the disk store. 
*/ - def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = { + def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = { if (contains(blockId)) { throw new IllegalStateException(s"Block $blockId is already present in the disk store") } logDebug(s"Attempting to put block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) - val fileOutputStream = new FileOutputStream(file) + val out = new CountingWritableChannel(openForWrite(file)) var threwException: Boolean = true try { - writeFunc(fileOutputStream) + writeFunc(out) + blockSizes.put(blockId.name, out.getCount) threwException = false } finally { try { - Closeables.close(fileOutputStream, threwException) + out.close() + } catch { + case ioe: IOException => + if (!threwException) { + threwException = true + throw ioe + } } finally { if (threwException) { remove(blockId) @@ -73,41 +92,46 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e } def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = { - put(blockId) { fileOutputStream => - val channel = fileOutputStream.getChannel - Utils.tryWithSafeFinally { - bytes.writeFully(channel) - } { - channel.close() - } + put(blockId) { channel => + bytes.writeFully(channel) } } - def getBytes(blockId: BlockId): ChunkedByteBuffer = { + def getBytes(blockId: BlockId): BlockData = { val file = diskManager.getFile(blockId.name) - val channel = new RandomAccessFile(file, "r").getChannel - Utils.tryWithSafeFinally { - // For small files, directly read rather than memory map - if (file.length < minMemoryMapBytes) { - val buf = ByteBuffer.allocate(file.length.toInt) - channel.position(0) - while (buf.remaining() != 0) { - if (channel.read(buf) == -1) { - throw new IOException("Reached EOF before filling buffer\n" + - s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") + val blockSize = getSize(blockId) + + securityManager.getIOEncryptionKey() match { + case Some(key) => + // Encrypted blocks cannot be memory mapped; return a special object that does decryption + // and provides InputStream / FileRegion implementations for reading the data. + new EncryptedBlockData(file, blockSize, conf, key) + + case _ => + val channel = new FileInputStream(file).getChannel() + if (blockSize < minMemoryMapBytes) { + // For small files, directly read rather than memory map. 
+ Utils.tryWithSafeFinally { + val buf = ByteBuffer.allocate(blockSize.toInt) + JavaUtils.readFully(channel, buf) + buf.flip() + new ByteBufferBlockData(new ChunkedByteBuffer(buf), true) + } { + channel.close() + } + } else { + Utils.tryWithSafeFinally { + new ByteBufferBlockData( + new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true) + } { + channel.close() } } - buf.flip() - new ChunkedByteBuffer(buf) - } else { - new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)) - } - } { - channel.close() } } def remove(blockId: BlockId): Boolean = { + blockSizes.remove(blockId.name) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() @@ -124,4 +148,142 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e val file = diskManager.getFile(blockId.name) file.exists() } + + private def openForWrite(file: File): WritableByteChannel = { + val out = new FileOutputStream(file).getChannel() + try { + securityManager.getIOEncryptionKey().map { key => + CryptoStreamUtils.createWritableChannel(out, conf, key) + }.getOrElse(out) + } catch { + case e: Exception => + Closeables.close(out, true) + file.delete() + throw e + } + } + +} + +private class EncryptedBlockData( + file: File, + blockSize: Long, + conf: SparkConf, + key: Array[Byte]) extends BlockData { + + override def toInputStream(): InputStream = Channels.newInputStream(open()) + + override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize) + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + val source = open() + try { + var remaining = blockSize + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, Int.MaxValue) + val chunk = allocator(chunkSize.toInt) + remaining -= chunkSize + JavaUtils.readFully(source, chunk) + chunk.flip() + chunks += chunk + } + + new ChunkedByteBuffer(chunks.toArray) + } finally { + source.close() + } + } + + override def toByteBuffer(): ByteBuffer = { + // This is used by the block transfer service to replicate blocks. The upload code reads + // all bytes into memory to send the block to the remote executor, so it's ok to do this + // as long as the block fits in a Java array. 
+ assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.") + val dst = ByteBuffer.allocate(blockSize.toInt) + val in = open() + try { + JavaUtils.readFully(in, dst) + dst.flip() + dst + } finally { + Closeables.close(in, true) + } + } + + override def size: Long = blockSize + + override def dispose(): Unit = { } + + private def open(): ReadableByteChannel = { + val channel = new FileInputStream(file).getChannel() + try { + CryptoStreamUtils.createReadableChannel(channel, conf, key) + } catch { + case e: Exception => + Closeables.close(channel, true) + throw e + } + } + +} + +private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long) + extends AbstractReferenceCounted with FileRegion { + + private var _transferred = 0L + + private val buffer = ByteBuffer.allocateDirect(64 * 1024) + buffer.flip() + + override def count(): Long = blockSize + + override def position(): Long = 0 + + override def transfered(): Long = _transferred + + override def transferTo(target: WritableByteChannel, pos: Long): Long = { + assert(pos == transfered(), "Invalid position.") + + var written = 0L + var lastWrite = -1L + while (lastWrite != 0) { + if (!buffer.hasRemaining()) { + buffer.clear() + source.read(buffer) + buffer.flip() + } + if (buffer.hasRemaining()) { + lastWrite = target.write(buffer) + written += lastWrite + } else { + lastWrite = 0 + } + } + + _transferred += written + written + } + + override def deallocate(): Unit = source.close() +} + +private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel { + + private var count = 0L + + def getCount: Long = count + + override def write(src: ByteBuffer): Int = { + val written = sink.write(src) + if (written > 0) { + count += written + } + written + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 5efdd23f79a2..241aacd74b58 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -236,14 +236,6 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { - // Ewwww... Reflection!!! See the unmap method for justification - private val memoryMappedBufferFileDescriptorField = { - val mappedBufferClass = classOf[java.nio.MappedByteBuffer] - val fdField = mappedBufferClass.getDeclaredField("fd") - fdField.setAccessible(true) - fdField - } - /** * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun * API that will cause errors if one attempts to read from the disposed buffer. However, neither @@ -251,8 +243,6 @@ private[spark] object StorageUtils extends Logging { * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of * off-heap memory or huge numbers of open files. There's unfortunately no standard API to * manually dispose of these kinds of buffers. - * - * See also [[unmap]] */ def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { @@ -261,28 +251,6 @@ private[spark] object StorageUtils extends Logging { } } - /** - * Attempt to unmap a ByteBuffer if it is memory-mapped. 
This uses an *unsafe* Sun API that will - * cause errors if one attempts to read from the unmapped buffer. However, the file descriptors of - * memory-mapped buffers do not put pressure on the garbage collector. Waiting for garbage - * collection may lead to huge numbers of open files. There's unfortunately no standard API to - * manually unmap memory-mapped buffers. - * - * See also [[dispose]] - */ - def unmap(buffer: ByteBuffer): Unit = { - if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - // Note that direct buffers are instances of MappedByteBuffer. As things stand in Java 8, the - // JDK does not provide a public API to distinguish between direct buffers and memory-mapped - // buffers. As an alternative, we peek beneath the curtains and look for a non-null file - // descriptor in mappedByteBuffer - if (memoryMappedBufferFileDescriptorField.get(buffer) != null) { - logTrace(s"Unmapping $buffer") - cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) - } - } - } - private def cleanDirectBuffer(buffer: DirectBuffer) = { val cleaner = buffer.cleaner() if (cleaner != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index fb54dd66a39a..90e3af2d0ec7 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -344,7 +344,7 @@ private[spark] class MemoryStore( val serializationStream: SerializationStream = { val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() - ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) } // Request enough memory to begin unrolling diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 1667516663b3..2f905c8af0f6 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -138,8 +138,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { /** * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped. * See [[StorageUtils.dispose]] for more information. - * - * See also [[unmap]] */ def dispose(): Unit = { if (!disposed) { @@ -148,18 +146,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } } - /** - * Attempt to unmap any ByteBuffer in this ChunkedByteBuffer if it is memory-mapped. See - * [[StorageUtils.unmap]] for more information. 
- * - * See also [[dispose]] - */ - def unmap(): Unit = { - if (!disposed) { - chunks.foreach(StorageUtils.unmap) - disposed = true - } - } } /** diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 4e36adc8baf3..84f7f1fc8eb0 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -28,7 +29,8 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext + with EncryptionFunSuite { val clusterUrl = "local-cluster[2,1,1024]" @@ -149,8 +151,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.parallelize(1 to 10).count() } - private def testCaching(storageLevel: StorageLevel): Unit = { - sc = new SparkContext(clusterUrl, "test") + private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { + sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) @@ -187,8 +189,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 ).foreach { case (testName, storageLevel) => - test(testName) { - testCaching(storageLevel) + encryptionTest(testName) { conf => + testCaching(conf, storageLevel) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 6646068d5080..82760fe92f76 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -24,8 +24,10 @@ import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ +import org.apache.spark.util.io.ChunkedByteBuffer // Dummy class that creates a broadcast variable but doesn't use it class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { @@ -43,7 +45,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class BroadcastSuite extends SparkFunSuite with LocalSparkContext { +class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite { test("Using TorrentBroadcast locally") { sc = new SparkContext("local", "test") @@ -61,9 +63,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } - test("Accessing TorrentBroadcast variables in a local cluster") { + encryptionTest("Accessing TorrentBroadcast variables in a local cluster") { conf => val numSlaves 
= 4 - val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -85,7 +86,9 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val size = 1 + rand.nextInt(1024 * 10) val data: Array[Byte] = new Array[Byte](size) rand.nextBytes(data) - val blocks = blockifyObject(data, blockSize, serializer, compressionCodec) + val blocks = blockifyObject(data, blockSize, serializer, compressionCodec).map { b => + new ChunkedByteBuffer(b).toInputStream(dispose = true) + } val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec) assert(unblockified === data) } @@ -137,9 +140,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } - test("Cache broadcast to disk") { - val conf = new SparkConf() - .setMaster("local") + encryptionTest("Cache broadcast to disk") { conf => + conf.setMaster("local") .setAppName("test") .set("spark.memory.useLegacyMode", "true") .set("spark.storage.memoryFraction", "0.0") diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 0f3a4a03618e..608052f5ed85 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.security -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.nio.channels.Channels import java.nio.charset.StandardCharsets.UTF_8 -import java.util.UUID +import java.nio.file.Files +import java.util.{Arrays, Random, UUID} import com.google.common.io.ByteStreams @@ -121,6 +123,46 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { } } + test("crypto stream wrappers") { + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = createConf() + val key = createKey(conf) + val file = Files.createTempFile("crypto", ".test").toFile() + + val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key) + try { + ByteStreams.copy(new ByteArrayInputStream(testData), outStream) + } finally { + outStream.close() + } + + val inStream = createCryptoInputStream(new FileInputStream(file), conf, key) + try { + val inStreamData = ByteStreams.toByteArray(inStream) + assert(Arrays.equals(inStreamData, testData)) + } finally { + inStream.close() + } + + val outChannel = createWritableChannel(new FileOutputStream(file).getChannel(), conf, key) + try { + val inByteChannel = Channels.newChannel(new ByteArrayInputStream(testData)) + ByteStreams.copy(inByteChannel, outChannel) + } finally { + outChannel.close() + } + + val inChannel = createReadableChannel(new FileInputStream(file).getChannel(), conf, key) + try { + val inChannelData = ByteStreams.toByteArray(Channels.newInputStream(inChannel)) + assert(Arrays.equals(inChannelData, testData)) + } finally { + inChannel.close() + } + } + private def createConf(extra: (String, String)*): SparkConf = { val conf = new SparkConf() extra.foreach { case (k, v) => conf.set(k, v) } diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala new file mode 100644 index 
000000000000..3f52dc41abf6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ + +trait EncryptionFunSuite { + + this: SparkFunSuite => + + /** + * Runs a test twice, initializing a SparkConf object with encryption off, then on. It's ok + * for the test to modify the provided SparkConf. + */ + final protected def encryptionTest(name: String)(fn: SparkConf => Unit) { + Seq(false, true).foreach { encrypt => + test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") { + val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) + fn(conf) + } + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 64a67b4c4cba..a8b960489983 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod +import org.apache.spark.internal.config._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -42,6 +43,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -49,7 +51,8 @@ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with PrivateMethodTester with LocalSparkContext with ResetSystemProperties { + with PrivateMethodTester with LocalSparkContext with ResetSystemProperties + with EncryptionFunSuite { import BlockManagerSuite._ @@ -75,16 +78,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER, master: BlockManagerMaster = this.master, - transferService: Option[BlockTransferService] = Option.empty): BlockManager = { - 
conf.set("spark.testing.memory", maxMem.toString) - conf.set("spark.memory.offHeap.size", maxMem.toString) - val serializer = new KryoSerializer(conf) + transferService: Option[BlockTransferService] = Option.empty, + testConf: Option[SparkConf] = None): BlockManager = { + val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf) + bmConf.set("spark.testing.memory", maxMem.toString) + bmConf.set("spark.memory.offHeap.size", maxMem.toString) + val serializer = new KryoSerializer(bmConf) + val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(bmConf)) + } else { + None + } + val bmSecurityMgr = new SecurityManager(bmConf, encryptionKey) val transfer = transferService .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1)) - val memManager = UnifiedMemoryManager(conf, numCores = 1) - val serializerManager = new SerializerManager(serializer, conf) - val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, - memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + val memManager = UnifiedMemoryManager(bmConf, numCores = 1) + val serializerManager = new SerializerManager(serializer, bmConf) + val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, bmConf, + memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") blockManager @@ -610,8 +621,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } - test("on-disk storage") { - store = makeBlockManager(1200) + encryptionTest("on-disk storage") { _conf => + store = makeBlockManager(1200, testConf = Some(_conf)) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -623,34 +634,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store") } - test("disk and memory storage") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false) + encryptionTest("disk and memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false, testConf = conf) } - test("disk and memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true) + encryptionTest("disk and memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true, testConf = conf) } - test("disk and memory storage with serialization") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false) + encryptionTest("disk and memory storage with serialization") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false, testConf = conf) } - test("disk and memory storage with serialization and getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true) + encryptionTest("disk and memory storage with serialization and getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true, testConf = conf) } - test("disk and off-heap memory storage") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false) + encryptionTest("disk and off-heap memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, 
getAsBytes = false, testConf = conf) } - test("disk and off-heap memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true) + encryptionTest("disk and off-heap memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true, testConf = conf) } def testDiskAndMemoryStorage( storageLevel: StorageLevel, - getAsBytes: Boolean): Unit = { - store = makeBlockManager(12000) + getAsBytes: Boolean, + testConf: SparkConf): Unit = { + store = makeBlockManager(12000, testConf = Some(testConf)) val accessMethod = if (getAsBytes) store.getLocalBytesAndReleaseLock else store.getSingleAndReleaseLock val a1 = new Array[Byte](4000) @@ -678,8 +690,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } - test("LRU with mixed storage levels") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) @@ -700,8 +712,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store") } - test("in-memory LRU with streams") { - store = makeBlockManager(12000) + encryptionTest("in-memory LRU with streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -728,8 +740,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getAndReleaseLock("list3") === None, "list1 was in store") } - test("LRU with mixed storage levels and streams") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels and streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -1325,7 +1337,8 @@ private object BlockManagerSuite { val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = { - wrapGet(store.getLocalBytes) + val allocator = ByteBuffer.allocate _ + wrapGet { bid => store.getLocalBytes(bid).map(_.toChunkedByteBuffer(allocator)) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9e6b02b9eac4..67fc084e8a13 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -18,15 +18,23 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.Arrays +import java.util.{Arrays, Random} -import org.apache.spark.{SparkConf, SparkFunSuite} +import com.google.common.io.{ByteStreams, Files} +import io.netty.channel.FileRegion + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} +import org.apache.spark.security.CryptoStreamUtils import 
org.apache.spark.util.io.ChunkedByteBuffer import org.apache.spark.util.Utils class DiskStoreSuite extends SparkFunSuite { test("reads of memory-mapped and non memory-mapped files are equivalent") { + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + // It will cause error when we tried to re-open the filestore and the // memory-mapped byte buffer tot he file has not been GC on Windows. assume(!Utils.isWindows) @@ -37,16 +45,18 @@ class DiskStoreSuite extends SparkFunSuite { val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes)) val blockId = BlockId("rdd_1_2") - val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager) + val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, + securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) - val mapped = diskStoreMapped.getBytes(blockId) + val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer assert(diskStoreMapped.remove(blockId)) - val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager) + val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, + securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) - val notMapped = diskStoreNotMapped.getBytes(blockId) + val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer // Not possible to do isInstanceOf due to visibility of HeapByteBuffer assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")), @@ -63,4 +73,95 @@ class DiskStoreSuite extends SparkFunSuite { assert(Arrays.equals(mapped.toArray, bytes)) assert(Arrays.equals(notMapped.toArray, bytes)) } + + test("block size tracking") { + val conf = new SparkConf() + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(new Array[Byte](32)) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === 32L) + diskStore.remove(blockId) + assert(diskStore.getSize(blockId) === 0L) + } + + test("block data encryption") { + val testDir = Utils.createTempDir() + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = new SparkConf() + val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf))) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(testData) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === testData.length) + + val diskData = Files.toByteArray(diskBlockManager.getFile(blockId.name)) + assert(!Arrays.equals(testData, diskData)) + + val blockData = diskStore.getBytes(blockId) + assert(blockData.isInstanceOf[EncryptedBlockData]) + assert(blockData.size === testData.length) + Map( + "input stream" -> readViaInputStream _, + "chunked byte buffer" -> readViaChunkedByteBuffer _, + "nio byte buffer" -> readViaNioBuffer _, + "managed buffer" -> readViaManagedBuffer _ + 
).foreach { case (name, fn) => + val readData = fn(blockData) + assert(readData.length === blockData.size, s"Size of data read via $name did not match.") + assert(Arrays.equals(testData, readData), s"Data read via $name did not match.") + } + } + + private def readViaInputStream(data: BlockData): Array[Byte] = { + val is = data.toInputStream() + try { + ByteStreams.toByteArray(is) + } finally { + is.close() + } + } + + private def readViaChunkedByteBuffer(data: BlockData): Array[Byte] = { + val buf = data.toChunkedByteBuffer(ByteBuffer.allocate _) + try { + buf.toArray + } finally { + buf.dispose() + } + } + + private def readViaNioBuffer(data: BlockData): Array[Byte] = { + JavaUtils.bufferToArray(data.toByteBuffer()) + } + + private def readViaManagedBuffer(data: BlockData): Array[Byte] = { + val region = data.toNetty().asInstanceOf[FileRegion] + val byteChannel = new ByteArrayWritableChannel(data.size.toInt) + + while (region.transfered() < region.count()) { + region.transferTo(byteChannel, region.transfered()) + } + + byteChannel.close() + byteChannel.getData + } + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index d0864fd3678b..844760ab61d2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -158,16 +158,14 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel, - encrypt = true) + blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } serializerManager .dataDeserializeStream( blockId, - new ChunkedByteBuffer(dataRead).toInputStream(), - maybeEncrypted = false)(elementClassTag) + new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 2b488038f062..80c07958b41f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -87,8 +87,7 @@ private[streaming] class BlockManagerBasedBlockHandler( putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes( - blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true, - encrypt = true) + blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true) case o => throw new SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") @@ -176,11 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler( val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = Some(arrayBuffer.size.toLong) - serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false) + serializerManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => val 
countIterator = new CountingIterator(iterator) - val serializedBlock = serializerManager.dataSerialize(blockId, countIterator, - allowEncryption = false) + val serializedBlock = serializerManager.dataSerialize(blockId, countIterator) numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => @@ -195,8 +193,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( blockId, serializedBlock, effectiveStorageLevel, - tellMaster = true, - encrypt = true) + tellMaster = true) if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c2b0389b8c6f..3c4a2716caf9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -175,8 +175,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) reader.close() serializerManager.dataDeserializeStream( generateBlockId(), - new ChunkedByteBuffer(bytes).toInputStream(), - maybeEncrypted = false)(ClassTag.Any).toList + new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList } loggedData shouldEqual data } @@ -357,7 +356,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) } def dataToByteBuffer(b: Seq[String]) = - serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false) + serializerManager.dataSerialize(generateBlockId, b.iterator) val blocks = data.grouped(10).toSeq diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 2ac0dc96916c..aa69be7ca993 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -250,8 +250,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false) - .toByteBuffer) + writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer) } writer.close() segments From c622a87c44e0621e1b3024fdca9b2aa3c508615b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 29 Mar 2017 10:09:58 -0700 Subject: [PATCH 159/512] [SPARK-20059][YARN] Use the correct classloader for HBaseCredentialProvider ## What changes were proposed in this pull request? Currently we use system classloader to find HBase jars, if it is specified by `--jars`, then it will be failed with ClassNotFound issue. So here changing to use child classloader. Also putting added jars and main jar into classpath of submitted application in yarn cluster mode, otherwise HBase jars specified with `--jars` will never be honored in cluster mode, and fetching tokens in client side will always be failed. ## How was this patch tested? Unit test and local verification. Author: jerryshao Closes #17388 from jerryshao/SPARK-20059. 
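For context, here is a minimal sketch (not part of the patch) of the classloader fallback this fix relies on: jars passed via `--jars` are visible to the thread's context classloader, not necessarily to the loader that loaded Spark's own classes, so HBase classes have to be looked up through the former. `HBaseTokenSketch` and `obtainHBaseToken` are hypothetical names for illustration only; the `TokenUtil` class and `obtainToken(Configuration)` signature are the ones the provider loads reflectively in the diff below, and the context-then-fallback lookup order mirrors what `Utils.getContextOrSparkClassLoader` does.

```scala
import org.apache.hadoop.conf.Configuration

// Hypothetical illustration, not Spark code.
object HBaseTokenSketch {
  def obtainHBaseToken(hbaseConf: Configuration): AnyRef = {
    // Prefer the thread's context classloader (which sees jars added via --jars),
    // falling back to the loader of this class.
    val loader = Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)
    // Load HBase's TokenUtil reflectively so Spark has no compile-time HBase dependency.
    val tokenUtil = loader.loadClass("org.apache.hadoop.hbase.security.token.TokenUtil")
    val obtainToken = tokenUtil.getMethod("obtainToken", classOf[Configuration])
    obtainToken.invoke(null, hbaseConf)
  }
}
```

Because classloaders delegate to their parents, this lookup still succeeds when the HBase jars happen to be on the system classpath; the change only adds visibility for the `--jars` case.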
--- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 7 ++++++- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 7 ++++++- .../deploy/yarn/security/HBaseCredentialProvider.scala | 5 +++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1e50eb663565..77005aa9040b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -485,12 +485,17 @@ object SparkSubmit extends CommandLineUtils { // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath - if (deployMode == CLIENT) { + // Also add the main application jar and any added jars to classpath in case YARN client + // requires these jars. + if (deployMode == CLIENT || isYarnCluster) { childMainClass = args.mainClass if (isUserJar(args.primaryResource)) { childClasspath += args.primaryResource } if (args.jars != null) { childClasspath ++= args.jars.split(",") } + } + + if (deployMode == CLIENT) { if (args.childArgs != null) { childArgs ++= args.childArgs } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 9417930d0240..a591b98bca48 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -213,7 +213,12 @@ class SparkSubmitSuite childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include regex ("--jar .*thejar.jar") mainClass should be ("org.apache.spark.deploy.yarn.Client") - classpath should have length (0) + + // In yarn cluster mode, also adding jars to classpath + classpath(0) should endWith ("thejar.jar") + classpath(1) should endWith ("one.jar") + classpath(2) should endWith ("two.jar") + classpath(3) should endWith ("three.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.driver.memory") should be ("4g") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala index 5571df09a2ec..5adeb8e605ff 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging { @@ -36,7 +37,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide sparkConf: SparkConf, creds: Credentials): Option[Long] = { try { - val mirror = universe.runtimeMirror(getClass.getClassLoader) + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) val obtainToken = mirror.classLoader. loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). 
getMethod("obtainToken", classOf[Configuration]) @@ -60,7 +61,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide private def hbaseConf(conf: Configuration): Configuration = { try { - val mirror = universe.runtimeMirror(getClass.getClassLoader) + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) val confCreate = mirror.classLoader. loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). getMethod("create", classOf[Configuration]) From d6ddfdf60e77340256873b5acf08e85f95cf3bc2 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 29 Mar 2017 11:41:17 -0700 Subject: [PATCH 160/512] [SPARK-19955][PYSPARK] Jenkins Python Conda based test. ## What changes were proposed in this pull request? Allow Jenkins Python tests to use the installed conda to test Python 2.7 support & test pip installability. ## How was this patch tested? Updated shell scripts, ran tests locally with installed conda, ran tests in Jenkins. Author: Holden Karau Closes #17355 from holdenk/SPARK-19955-support-python-tests-with-conda. --- dev/run-pip-tests | 66 +++++++++++++++++++++++++++---------------- dev/run-tests-jenkins | 3 +- python/run-tests.py | 6 ++-- 3 files changed, 47 insertions(+), 28 deletions(-) diff --git a/dev/run-pip-tests b/dev/run-pip-tests index af1b1feb70cd..d51dde12a03c 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -35,9 +35,28 @@ function delete_virtualenv() { } trap delete_virtualenv EXIT +PYTHON_EXECS=() # Some systems don't have pip or virtualenv - in those cases our tests won't work. -if ! hash virtualenv 2>/dev/null; then - echo "Missing virtualenv skipping pip installability tests." +if hash virtualenv 2>/dev/null && [ ! -n "$USE_CONDA" ]; then + echo "virtualenv installed - using. Note if this is a conda virtual env you may wish to set USE_CONDA" + # Figure out which Python execs we should test pip installation with + if hash python2 2>/dev/null; then + # We do this since we are testing with virtualenv and the default virtual env python + # is in /usr/bin/python + PYTHON_EXECS+=('python2') + elif hash python 2>/dev/null; then + # If python2 isn't installed fallback to python if available + PYTHON_EXECS+=('python') + fi + if hash python3 2>/dev/null; then + PYTHON_EXECS+=('python3') + fi +elif hash conda 2>/dev/null; then + echo "Using conda virtual enviroments" + PYTHON_EXECS=('3.5') + USE_CONDA=1 +else + echo "Missing virtualenv & conda, skipping pip installability tests" exit 0 fi if ! hash pip 2>/dev/null; then @@ -45,22 +64,8 @@ if ! 
hash pip 2>/dev/null; then exit 0 fi -# Figure out which Python execs we should test pip installation with -PYTHON_EXECS=() -if hash python2 2>/dev/null; then - # We do this since we are testing with virtualenv and the default virtual env python - # is in /usr/bin/python - PYTHON_EXECS+=('python2') -elif hash python 2>/dev/null; then - # If python2 isn't installed fallback to python if available - PYTHON_EXECS+=('python') -fi -if hash python3 2>/dev/null; then - PYTHON_EXECS+=('python3') -fi - # Determine which version of PySpark we are building for archive name -PYSPARK_VERSION=$(python -c "exec(open('python/pyspark/version.py').read());print __version__") +PYSPARK_VERSION=$(python3 -c "exec(open('python/pyspark/version.py').read());print(__version__)") PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" # The pip install options we use for all the pip commands PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " @@ -75,18 +80,24 @@ for python in "${PYTHON_EXECS[@]}"; do echo "Using $VIRTUALENV_BASE for virtualenv" VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python rm -rf "$VIRTUALENV_PATH" - mkdir -p "$VIRTUALENV_PATH" - virtualenv --python=$python "$VIRTUALENV_PATH" - source "$VIRTUALENV_PATH"/bin/activate - # Upgrade pip & friends - pip install --upgrade pip pypandoc wheel - pip install numpy # Needed so we can verify mllib imports + if [ -n "$USE_CONDA" ]; then + conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools + source activate "$VIRTUALENV_PATH" + else + mkdir -p "$VIRTUALENV_PATH" + virtualenv --python=$python "$VIRTUALENV_PATH" + source "$VIRTUALENV_PATH"/bin/activate + fi + # Upgrade pip & friends if using virutal env + if [ ! -n "USE_CONDA" ]; then + pip install --upgrade pip pypandoc wheel numpy + fi echo "Creating pip installable source dist" cd "$FWDIR"/python # Delete the egg info file if it exists, this can cache the setup file. rm -rf pyspark.egg-info || echo "No existing egg info file, skipping deletion" - $python setup.py sdist + python setup.py sdist echo "Installing dist into virtual env" @@ -112,6 +123,13 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR" + # conda / virtualenv enviroments need to be deactivated differently + if [ -n "$USE_CONDA" ]; then + source deactivate + else + deactivate + fi + done done diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index e79accf9e987..f41f1ac79e38 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -22,7 +22,8 @@ # Environment variables are populated by the code here: #+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 -FWDIR="$(cd "`dirname $0`"/..; pwd)" +FWDIR="$( cd "$( dirname "$0" )/.." 
&& pwd )" cd "$FWDIR" +export PATH=/home/anaconda/bin:$PATH exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/python/run-tests.py b/python/run-tests.py index 53a0aef229b0..b2e50435bb19 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -111,9 +111,9 @@ def run_individual_python_test(test_name, pyspark_python): def get_default_python_executables(): - python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] - if "python2.6" not in python_execs: - LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + python_execs = [x for x in ["python2.7", "python3.4", "pypy"] if which(x)] + if "python2.7" not in python_execs: + LOGGER.warning("Not testing against `python2.7` because it could not be found; falling" " back to `python` instead") python_execs.insert(0, "python") return python_execs From 142f6d14928c780cc9e8d6d7749c5d7c08a30972 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Wed, 29 Mar 2017 12:35:19 -0700 Subject: [PATCH 161/512] [SPARK-20048][SQL] Cloning SessionState does not clone query execution listeners ## What changes were proposed in this pull request? Bugfix from [SPARK-19540.](https://github.com/apache/spark/pull/16826) Cloning SessionState does not clone query execution listeners, so cloned session is unable to listen to events on queries. ## How was this patch tested? - Unit test Author: Kunal Khamar Closes #17379 from kunalkhamar/clone-bugfix. --- .../org/apache/spark/sql/SparkSession.scala | 22 ++++---- ...rs.scala => BaseSessionStateBuilder.scala} | 24 ++++++++- .../spark/sql/internal/SessionState.scala | 38 ++++--------- .../sql/util/QueryExecutionListener.scala | 10 ++++ .../apache/spark/sql/SessionStateSuite.scala | 53 +++++++++++++++++++ .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- ...te.scala => HiveSessionStateBuilder.scala} | 14 +---- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- .../sql/hive/HiveSessionStateSuite.scala | 2 +- 9 files changed, 111 insertions(+), 56 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/internal/{sessionStateBuilders.scala => BaseSessionStateBuilder.scala} (92%) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{HiveSessionState.scala => HiveSessionStateBuilder.scala} (92%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 49562578b23c..a97297892b5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} +import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -194,7 +194,7 @@ class SparkSession private( * * @since 2.0.0 */ - def udf: UDFRegistration = sessionState.udf + def udf: UDFRegistration = sessionState.udfRegistration /** * :: Experimental :: @@ -990,28 +990,28 @@ object SparkSession { /** Reference to the root SparkSession. 
*/ private val defaultSession = new AtomicReference[SparkSession] - private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState" + private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME = + "org.apache.spark.sql.hive.HiveSessionStateBuilder" private def sessionStateClassName(conf: SparkConf): String = { conf.get(CATALOG_IMPLEMENTATION) match { - case "hive" => HIVE_SESSION_STATE_CLASS_NAME - case "in-memory" => classOf[SessionState].getCanonicalName + case "hive" => HIVE_SESSION_STATE_BUILDER_CLASS_NAME + case "in-memory" => classOf[SessionStateBuilder].getCanonicalName } } /** * Helper method to create an instance of `SessionState` based on `className` from conf. - * The result is either `SessionState` or `HiveSessionState`. + * The result is either `SessionState` or a Hive based `SessionState`. */ private def instantiateSessionState( className: String, sparkSession: SparkSession): SessionState = { - try { - // get `SessionState.apply(SparkSession)` + // invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])` val clazz = Utils.classForName(className) - val method = clazz.getMethod("apply", sparkSession.getClass) - method.invoke(null, sparkSession).asInstanceOf[SessionState] + val ctor = clazz.getConstructors.head + ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build() } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Error while instantiating '$className':", e) @@ -1023,7 +1023,7 @@ object SparkSession { */ private[spark] def hiveClassesArePresent: Boolean = { try { - Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME) + Utils.classForName(HIVE_SESSION_STATE_BUILDER_CLASS_NAME) Utils.classForName("org.apache.hadoop.hive.conf.HiveConf") true } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala rename to sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index b8f645fdee85..2b14eca919fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.StreamingQueryManager +import org.apache.spark.sql.util.ExecutionListenerManager /** * Builder class that coordinates construction of a new [[SessionState]]. @@ -133,6 +134,14 @@ abstract class BaseSessionStateBuilder( catalog } + /** + * Interface exposed to the user for registering user-defined functions. + * + * Note 1: The user-defined functions must be deterministic. 
+ * Note 2: This depends on the `functionRegistry` field. + */ + protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry) + /** * Logical query plan analyzer for resolving unresolved attributes and relations. * @@ -232,6 +241,16 @@ abstract class BaseSessionStateBuilder( */ protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session) + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + * + * This gets cloned from parent if available, otherwise is a new instance is created. + */ + protected def listenerManager: ExecutionListenerManager = { + parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager) + } + /** * Function used to make clones of the session state. */ @@ -245,17 +264,18 @@ abstract class BaseSessionStateBuilder( */ def build(): SessionState = { new SessionState( - session.sparkContext, session.sharedState, conf, experimentalMethods, functionRegistry, + udfRegistration, catalog, sqlParser, analyzer, optimizer, planner, streamingQueryManager, + listenerManager, resourceLoader, createQueryExecution, createClone) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index c6241d923d7b..1b341a12fc60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -32,43 +32,46 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.streaming.StreamingQueryManager -import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.util.{ExecutionListenerManager, QueryExecutionListener} /** * A class that holds all session-specific state in a given [[SparkSession]]. * - * @param sparkContext The [[SparkContext]]. - * @param sharedState The shared state. + * @param sharedState The state shared across sessions, e.g. global view manager, external catalog. * @param conf SQL-specific key-value configurations. - * @param experimentalMethods The experimental methods. + * @param experimentalMethods Interface to add custom planning strategies and optimizers. * @param functionRegistry Internal catalog for managing functions registered by the user. + * @param udfRegistration Interface exposed to the user for registering user-defined functions. * @param catalog Internal catalog for managing table and database states. * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. * @param optimizer Logical query plan optimizer. - * @param planner Planner that converts optimized logical plans to physical plans + * @param planner Planner that converts optimized logical plans to physical plans. * @param streamingQueryManager Interface to start and stop streaming queries. + * @param listenerManager Interface to register custom [[QueryExecutionListener]]s. + * @param resourceLoader Session shared resource loader to load JARs, files, etc. * @param createQueryExecution Function used to create QueryExecution objects. * @param createClone Function used to create clones of the session state. 
*/ private[sql] class SessionState( - sparkContext: SparkContext, sharedState: SharedState, val conf: SQLConf, val experimentalMethods: ExperimentalMethods, val functionRegistry: FunctionRegistry, + val udfRegistration: UDFRegistration, val catalog: SessionCatalog, val sqlParser: ParserInterface, val analyzer: Analyzer, val optimizer: Optimizer, val planner: SparkPlanner, val streamingQueryManager: StreamingQueryManager, + val listenerManager: ExecutionListenerManager, val resourceLoader: SessionResourceLoader, createQueryExecution: LogicalPlan => QueryExecution, createClone: (SparkSession, SessionState) => SessionState) { def newHadoopConf(): Configuration = SessionState.newHadoopConf( - sparkContext.hadoopConfiguration, + sharedState.sparkContext.hadoopConfiguration, conf) def newHadoopConfWithOptions(options: Map[String, String]): Configuration = { @@ -81,18 +84,6 @@ private[sql] class SessionState( hadoopConf } - /** - * Interface exposed to the user for registering user-defined functions. - * Note that the user-defined functions must be deterministic. - */ - val udf: UDFRegistration = new UDFRegistration(functionRegistry) - - /** - * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s - * that listen for execution metrics. - */ - val listenerManager: ExecutionListenerManager = new ExecutionListenerManager - /** * Get an identical copy of the `SessionState` and associate it with the given `SparkSession` */ @@ -110,13 +101,6 @@ private[sql] class SessionState( } private[sql] object SessionState { - /** - * Create a new [[SessionState]] for the given session. - */ - def apply(session: SparkSession): SessionState = { - new SessionStateBuilder(session).build() - } - def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = { val newHadoopConf = new Configuration(hadoopConf) sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) } @@ -155,7 +139,7 @@ class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoade /** * Add a jar path to [[SparkContext]] and the classloader. * - * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs + * Note: this method seems not access any session state, but a Hive based `SessionState` needs * to add the jar to its hive client for the current session. Hence, it still needs to be in * [[SessionState]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 26ad0eadd9d4..f6240d85fba6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -98,6 +98,16 @@ class ExecutionListenerManager private[sql] () extends Logging { listeners.clear() } + /** + * Get an identical copy of this listener manager. 
+ */ + @DeveloperApi + override def clone(): ExecutionListenerManager = writeLock { + val newListenerManager = new ExecutionListenerManager + listeners.foreach(newListenerManager.register) + newListenerManager + } + private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { readLock { withErrorHandling { listener => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 2d5e37242a58..5638c8eeda84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -19,10 +19,13 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener class SessionStateSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { @@ -122,6 +125,56 @@ class SessionStateSuite extends SparkFunSuite } } + test("fork new session and inherit listener manager") { + class CommandCollector extends QueryExecutionListener { + val commands: ArrayBuffer[String] = ArrayBuffer.empty[String] + override def onFailure(funcName: String, qe: QueryExecution, ex: Exception) : Unit = {} + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + commands += funcName + } + } + val collectorA = new CommandCollector + val collectorB = new CommandCollector + val collectorC = new CommandCollector + + try { + def runCollectQueryOn(sparkSession: SparkSession): Unit = { + val tupleEncoder = Encoders.tuple(Encoders.scalaInt, Encoders.STRING) + val df = sparkSession.createDataset(Seq(1 -> "a"))(tupleEncoder).toDF("i", "j") + df.select("i").collect() + } + + activeSession.listenerManager.register(collectorA) + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.listenerManager ne activeSession.listenerManager) + runCollectQueryOn(forkedSession) + assert(collectorA.commands.length == 1) // forked should callback to A + assert(collectorA.commands(0) == "collect") + + // independence + // => changes to forked do not affect original + forkedSession.listenerManager.register(collectorB) + runCollectQueryOn(activeSession) + assert(collectorB.commands.isEmpty) // original should not callback to B + assert(collectorA.commands.length == 2) // original should still callback to A + assert(collectorA.commands(1) == "collect") + // <= changes to original do not affect forked + activeSession.listenerManager.register(collectorC) + runCollectQueryOn(forkedSession) + assert(collectorC.commands.isEmpty) // forked should not callback to C + assert(collectorA.commands.length == 3) // forked should still callback to A + assert(collectorB.commands.length == 1) // forked should still callback to B + assert(collectorA.commands(2) == "collect") + assert(collectorB.commands(0) == "collect") + } finally { + activeSession.listenerManager.unregister(collectorA) + activeSession.listenerManager.unregister(collectorC) + } + } + test("fork new sessions and run query on inherited table") { def checkTableExists(sparkSession: SparkSession): Unit = { 
QueryTest.checkAnswer(sparkSession.sql( diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 0c79b6f4211f..390b9b6d68ca 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -38,7 +38,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.util.ShutdownHookManager /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala similarity index 92% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index f49e6bb41864..8048c2ba2c2e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -28,19 +28,7 @@ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** - * Entry object for creating a Hive aware [[SessionState]]. - */ -private[hive] object HiveSessionState { - /** - * Create a new Hive aware [[SessionState]]. for the given session. - */ - def apply(session: SparkSession): SessionState = { - new HiveSessionStateBuilder(session).build() - } -} - -/** - * Builder that produces a [[HiveSessionState]]. + * Builder that produces a Hive aware [[SessionState]]. */ @Experimental @InterfaceStability.Unstable diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 0bcf21992276..d9bb1f8c7edc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal._ +import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.{ShutdownHookManager, Utils} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index 67c77fb62f4e..958ad3e1c3ce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton /** - * Run all tests from `SessionStateSuite` with a `HiveSessionState`. + * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. 
*/ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton with BeforeAndAfterEach { From c4008480b781379ac0451b9220300d83c054c60d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 29 Mar 2017 12:37:49 -0700 Subject: [PATCH 162/512] [SPARK-20009][SQL] Support DDL strings for defining schema in functions.from_json ## What changes were proposed in this pull request? This pr added `StructType.fromDDL` to convert a DDL format string into `StructType` for defining schemas in `functions.from_json`. ## How was this patch tested? Added tests in `JsonFunctionsSuite`. Author: Takeshi Yamamuro Closes #17406 from maropu/SPARK-20009. --- .../apache/spark/sql/types/StructType.scala | 6 ++ .../spark/sql/types/DataTypeSuite.scala | 85 ++++++++++++++----- .../org/apache/spark/sql/functions.scala | 15 +++- .../apache/spark/sql/JsonFunctionsSuite.scala | 7 ++ .../sql/sources/SimpleTextRelation.scala | 2 +- 5 files changed, 90 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 8d8b5b86d5aa..54006e20a3eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -417,6 +417,12 @@ object StructType extends AbstractDataType { } } + /** + * Creates StructType for a given DDL-formatted string, which is a comma separated list of field + * definitions, e.g., a INT, b STRING. + */ + def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl) + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 61e1ec7c7ab3..05cb999af6a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -169,30 +169,72 @@ class DataTypeSuite extends SparkFunSuite { assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) } - def checkDataTypeJsonRepr(dataType: DataType): Unit = { - test(s"JSON - $dataType") { + def checkDataTypeFromJson(dataType: DataType): Unit = { + test(s"from Json - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) } } - checkDataTypeJsonRepr(NullType) - checkDataTypeJsonRepr(BooleanType) - checkDataTypeJsonRepr(ByteType) - checkDataTypeJsonRepr(ShortType) - checkDataTypeJsonRepr(IntegerType) - checkDataTypeJsonRepr(LongType) - checkDataTypeJsonRepr(FloatType) - checkDataTypeJsonRepr(DoubleType) - checkDataTypeJsonRepr(DecimalType(10, 5)) - checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT) - checkDataTypeJsonRepr(DateType) - checkDataTypeJsonRepr(TimestampType) - checkDataTypeJsonRepr(StringType) - checkDataTypeJsonRepr(BinaryType) - checkDataTypeJsonRepr(ArrayType(DoubleType, true)) - checkDataTypeJsonRepr(ArrayType(StringType, false)) - checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) - checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + def checkDataTypeFromDDL(dataType: DataType): Unit = { + test(s"from DDL - $dataType") { + val parsed = StructType.fromDDL(s"a ${dataType.sql}") + val expected = new StructType().add("a", dataType) + assert(parsed.sameType(expected)) + } + } + + 
checkDataTypeFromJson(NullType) + + checkDataTypeFromJson(BooleanType) + checkDataTypeFromDDL(BooleanType) + + checkDataTypeFromJson(ByteType) + checkDataTypeFromDDL(ByteType) + + checkDataTypeFromJson(ShortType) + checkDataTypeFromDDL(ShortType) + + checkDataTypeFromJson(IntegerType) + checkDataTypeFromDDL(IntegerType) + + checkDataTypeFromJson(LongType) + checkDataTypeFromDDL(LongType) + + checkDataTypeFromJson(FloatType) + checkDataTypeFromDDL(FloatType) + + checkDataTypeFromJson(DoubleType) + checkDataTypeFromDDL(DoubleType) + + checkDataTypeFromJson(DecimalType(10, 5)) + checkDataTypeFromDDL(DecimalType(10, 5)) + + checkDataTypeFromJson(DecimalType.SYSTEM_DEFAULT) + checkDataTypeFromDDL(DecimalType.SYSTEM_DEFAULT) + + checkDataTypeFromJson(DateType) + checkDataTypeFromDDL(DateType) + + checkDataTypeFromJson(TimestampType) + checkDataTypeFromDDL(TimestampType) + + checkDataTypeFromJson(StringType) + checkDataTypeFromDDL(StringType) + + checkDataTypeFromJson(BinaryType) + checkDataTypeFromDDL(BinaryType) + + checkDataTypeFromJson(ArrayType(DoubleType, true)) + checkDataTypeFromDDL(ArrayType(DoubleType, true)) + + checkDataTypeFromJson(ArrayType(StringType, false)) + checkDataTypeFromDDL(ArrayType(StringType, false)) + + checkDataTypeFromJson(MapType(IntegerType, StringType, true)) + checkDataTypeFromDDL(MapType(IntegerType, StringType, true)) + + checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false)) val metadata = new MetadataBuilder() .putString("name", "age") @@ -201,7 +243,8 @@ class DataTypeSuite extends SparkFunSuite { StructField("a", IntegerType, nullable = true), StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", DoubleType, nullable = false, metadata))) - checkDataTypeJsonRepr(structType) + checkDataTypeFromJson(structType) + checkDataTypeFromDDL(structType) def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { test(s"Check the default size of $dataType") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index acdb8e2d3edc..0f9203065ef0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try +import scala.util.control.NonFatal import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection @@ -3055,13 +3056,21 @@ object functions { * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string as a json string + * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, + * the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL + * format is also supported for the schema. 
* * @group collection_funcs * @since 2.1.0 */ - def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = - from_json(e, DataType.fromJson(schema), options) + def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = { + val dataType = try { + DataType.fromJson(schema) + } catch { + case NonFatal(_) => StructType.fromDDL(schema) + } + from_json(e, dataType, options) + } /** * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 170c238c5343..8465e8d036a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -156,6 +156,13 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) } + test("from_json uses DDL strings for defining a schema") { + val df = Seq("""{"a": 1, "b": "haa"}""").toDS() + checkAnswer( + df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + test("to_json - struct") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 1607c97cd6ac..9f4009bfe402 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.spark.sql.{sources, Row, SparkSession} +import org.apache.spark.sql.{sources, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection From 5c8ef376e874497766ba0cc4d97429e33a3d9c61 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 29 Mar 2017 12:43:22 -0700 Subject: [PATCH 163/512] [SPARK-17075][SQL][FOLLOWUP] Add Estimation of Constant Literal ### What changes were proposed in this pull request? `FalseLiteral` and `TrueLiteral` should have been eliminated by optimizer rule `BooleanSimplification`, but null literals might be added by optimizer rule `NullPropagation`. For safety, our filter estimation should handle all the eligible literal cases. Our optimizer rule BooleanSimplification is unable to remove the null literal in many cases. For example, `a < 0 or null`. Thus, we need to handle null literal in filter estimation. `Not` can be pushed down below `And` and `Or`. Then, we could see two consecutive `Not`, which need to be collapsed into one. Because of the limited expression support for filter estimation, we just need to handle the case `Not(null)` for avoiding incorrect error due to the boolean operation on null. For details, see below matrix. ``` not NULL = NULL NULL or false = NULL NULL or true = true NULL or NULL = NULL NULL and false = false NULL and true = NULL NULL and NULL = NULL ``` ### How was this patch tested? Added the test cases. 
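For reference, the NULL semantics in the matrix above can be sketched in plain Scala with `Option[Boolean]`, where `None` stands for NULL (an illustrative snippet, not code from this patch):

```
// Three-valued logic matching the matrix above; None models NULL.
def not3(a: Option[Boolean]): Option[Boolean] = a.map(!_)                  // not NULL = NULL

def and3(a: Option[Boolean], b: Option[Boolean]): Option[Boolean] = (a, b) match {
  case (Some(false), _) | (_, Some(false)) => Some(false)                  // NULL and false = false
  case (Some(true), Some(true))            => Some(true)
  case _                                   => None                         // NULL and true/NULL = NULL
}

def or3(a: Option[Boolean], b: Option[Boolean]): Option[Boolean] = (a, b) match {
  case (Some(true), _) | (_, Some(true)) => Some(true)                     // NULL or true = true
  case (Some(false), Some(false))        => Some(false)
  case _                                 => None                           // NULL or false/NULL = NULL
}
```

In the estimation code below, a literal predicate that evaluates to NULL or false keeps no rows (selectivity 0.0), while a true literal keeps all rows (selectivity 1.0).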
Author: Xiao Li Closes #17446 from gatorsmile/constantFilterEstimation. --- .../statsEstimation/FilterEstimation.scala | 39 ++++++++- .../FilterEstimationSuite.scala | 87 +++++++++++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index f14df93160b7..b32374c5742e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -24,6 +24,7 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -104,12 +105,23 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) Some(percent1 + percent2 - (percent1 * percent2)) + // Not-operator pushdown case Not(And(cond1, cond2)) => calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + // Not-operator pushdown case Not(Or(cond1, cond2)) => calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) + // Collapse two consecutive Not operators which could be generated after Not-operator pushdown + case Not(Not(cond)) => + calculateFilterSelectivity(cond, update = false) + + // The foldable Not has been processed in the ConstantFolding rule + // This is a top-down traversal. The Not could be pushed down by the above two cases. + case Not(l @ Literal(null, _)) => + calculateSingleCondition(l, update = false) + case Not(cond) => calculateFilterSelectivity(cond, update = false) match { case Some(percent) => Some(1.0 - percent) @@ -134,13 +146,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { condition match { + case l: Literal => + evaluateLiteral(l) + // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. // EqualTo/EqualNullSafe does not care about the order - case op @ Equality(ar: Attribute, l: Literal) => + case Equality(ar: Attribute, l: Literal) => evaluateEquality(ar, l, update) - case op @ Equality(l: Literal, ar: Attribute) => + case Equality(l: Literal, ar: Attribute) => evaluateEquality(ar, l, update) case op @ LessThan(ar: Attribute, l: Literal) => @@ -342,6 +357,26 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } + /** + * Returns a percentage of rows meeting a Literal expression. + * This method evaluates all the possible literal cases in Filter. + * + * FalseLiteral and TrueLiteral should be eliminated by optimizer, but null literal might be added + * by optimizer rule NullPropagation. For safety, we handle all the cases here. 
+ * + * @param literal a literal value (or constant) + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateLiteral(literal: Literal): Option[Double] = { + literal match { + case Literal(null, _) => Some(0.0) + case FalseLiteral => Some(0.0) + case TrueLiteral => Some(1.0) + // Ideally, we should not hit the following branch + case _ => None + } + } + /** * Returns a percentage of rows meeting "IN" operator expression. * This method evaluates the equality predicate for all data types. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 07abe1ed2853..1966c96c0529 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ @@ -76,6 +77,82 @@ class FilterEstimationSuite extends StatsEstimationTestBase { attrDouble -> colStatDouble, attrString -> colStatString)) + test("true") { + validateEstimatedStats( + Filter(TrueLiteral, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 10) + } + + test("false") { + validateEstimatedStats( + Filter(FalseLiteral, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("null") { + validateEstimatedStats( + Filter(Literal(null, IntegerType), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(null)") { + validateEstimatedStats( + Filter(Not(Literal(null, IntegerType)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(Not(null))") { + validateEstimatedStats( + Filter(Not(Not(Literal(null, IntegerType))), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cint < 3 AND null") { + val condition = And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cint < 3 OR null") { + val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) + val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 3) + } + + test("Not(cint < 3 AND null)") { + val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 8) + } + + test("Not(cint < 3 OR null)") { + val condition = Not(Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(cint < 3 AND Not(null))") { + val 
condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 8) + } + test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), @@ -163,6 +240,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 10) } + test("cint IS NOT NULL && null") { + // 'cint < null' will be optimized to 'cint IS NOT NULL && null'. + // More similar cases can be found in the Optimizer NullPropagation. + val condition = And(IsNotNull(attrInt), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + test("cint > 3 AND cint <= 6") { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( From fe1d6b05d47e384e3710ae428db499e89697267f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 29 Mar 2017 15:23:24 -0700 Subject: [PATCH 164/512] [SPARK-20120][SQL] spark-sql support silent mode ## What changes were proposed in this pull request? It is similar to Hive silent mode, just show the query result. see: [Hive LanguageManual+Cli](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Cli) and [the implementation of Hive silent mode](https://github.com/apache/hive/blob/release-1.2.1/ql/src/java/org/apache/hadoop/hive/ql/session/SessionState.java#L948-L950). This PR set the Logger level to `WARN` to get similar result. ## How was this patch tested? manual tests ![manual test spark sql silent mode](https://cloud.githubusercontent.com/assets/5399861/24390165/989b7780-13b9-11e7-8496-6e68f55757e3.gif) Author: Yuming Wang Closes #17449 from wangyum/SPARK-20120. --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 390b9b6d68ca..1bc5c3c62f04 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.internal.Logging @@ -275,6 +276,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + if (sessionState.getIsSilent) { + Logger.getRootLogger.setLevel(Level.WARN) + } + private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } From dd2e7d528cb7468cdc077403f314c7ee0f214ac5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 29 Mar 2017 17:32:01 -0700 Subject: [PATCH 165/512] [SPARK-19088][SQL] Fix 2.10 build. ## What changes were proposed in this pull request? Commit 6c70a38 broke the build for scala 2.10. The commit uses some reflections which are not available in Scala 2.10. This PR fixes them. ## How was this patch tested? Existing tests. 
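To make the incompatibility concrete, the lookup of the companion object's `newBuilder` member can be written with the 2.10-era reflection API exactly as in the diff below (a small self-contained sketch; the 2.11-only spelling would be `t.dealias.companion.decl(TermName("newBuilder"))`):

```
import scala.reflect.runtime.universe._

val t = typeOf[Seq[Int]]
// 2.10-compatible reflection calls, as used by this patch.
val companion = t.normalize.typeSymbol.companionSymbol.typeSignature
val newBuilderSym = companion.declaration(newTermName("newBuilder"))
// newBuilderSym is NoSymbol when the companion does not declare newBuilder directly;
// the patched ScalaReflection pattern-matches on that to fall back to classOf[Seq[_]].
```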
Author: Takuya UESHIN Closes #17473 from ueshin/issues/SPARK-19088. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1c7720afe1ca..da37eb00dcd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,7 +307,8 @@ object ScalaReflection extends ScalaReflection { } } - val cls = t.dealias.companion.decl(TermName("newBuilder")) match { + val companion = t.normalize.typeSymbol.companionSymbol.typeSignature + val cls = companion.declaration(newTermName("newBuilder")) match { case NoSymbol => classOf[Seq[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } From 22f07fefe11f0147f1e8d83d9b77707640d5dc97 Mon Sep 17 00:00:00 2001 From: bomeng Date: Wed, 29 Mar 2017 18:57:35 -0700 Subject: [PATCH 166/512] [SPARK-20146][SQL] fix comment missing issue for thrift server ## What changes were proposed in this pull request? The column comment was missing while constructing the Hive TableSchema. This fix will preserve the original comment. ## How was this patch tested? I have added a new test case to test the column with/without comment. Author: bomeng Closes #17470 from bomeng/SPARK-20146. --- .../SparkExecuteStatementOperation.scala | 2 +- .../SparkExecuteStatementOperationSuite.scala | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 517b01f18392..ff3784cab9e2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -292,7 +292,7 @@ object SparkExecuteStatementOperation { def getTableSchema(structType: StructType): TableSchema = { val schema = structType.map { field => val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString - new FieldSchema(field.name, attrTypeString, "") + new FieldSchema(field.name, attrTypeString, field.getComment.getOrElse("")) } new TableSchema(schema.asJava) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala index 32ded0d254ef..06e398066204 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{NullType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType} class SparkExecuteStatementOperationSuite extends SparkFunSuite { test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") { @@ 
-30,4 +30,16 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite { assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) } + + test("SPARK-20146 Comment should be preserved") { + val field1 = StructField("column1", StringType).withComment("comment 1") + val field2 = StructField("column2", IntegerType) + val tableSchema = StructType(Seq(field1, field2)) + val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors() + assert(columns.size() == 2) + assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE) + assert(columns.get(0).getComment() == "comment 1") + assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE) + assert(columns.get(1).getComment() == "") + } } From 60977889eaecdf28adc6164310eaa5afed488fa1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Mar 2017 19:06:51 -0700 Subject: [PATCH 167/512] [SPARK-20136][SQL] Add num files and metadata operation timing to scan operator metrics ## What changes were proposed in this pull request? This patch adds explicit metadata operation timing and number of files in data source metrics. Those would be useful to include for performance profiling. Screenshot of a UI with this change (num files and metadata time are new metrics): screen shot 2017-03-29 at 12 29 28 am ## How was this patch tested? N/A Author: Reynold Xin Closes #17465 from rxin/SPARK-20136. --- .../sql/execution/DataSourceScanExec.scala | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 28156b277f59..239151495f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -171,8 +171,20 @@ case class FileSourceScanExec( false } - @transient private lazy val selectedPartitions = - relation.location.listFiles(partitionFilters, dataFilters) + @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { + val startTime = System.nanoTime() + val ret = relation.location.listFiles(partitionFilters, dataFilters) + val timeTaken = (System.nanoTime() - startTime) / 1000 / 1000 + + metrics("numFiles").add(ret.map(_.files.size.toLong).sum) + metrics("metadataTime").add(timeTaken) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics("numFiles") :: metrics("metadataTime") :: Nil) + + ret + } override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { @@ -293,6 +305,8 @@ case class FileSourceScanExec( override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"), + "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time (ms)"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { From 79636054f60dd639e9d326e1328717e97df13304 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 29 Mar 2017 20:59:48 -0700 Subject: [PATCH 168/512] 
[SPARK-20148][SQL] Extend the file commit API to allow subscribing to task commit messages ## What changes were proposed in this pull request? The internal FileCommitProtocol interface returns all task commit messages in bulk to the implementation when a job finishes. However, it is sometimes useful to access those messages before the job completes, so that the driver gets incremental progress updates before the job finishes. This adds an `onTaskCommit` listener to the internal api. ## How was this patch tested? Unit tests. cc rxin Author: Eric Liang Closes #17475 from ericl/file-commit-api-ext. --- .../internal/io/FileCommitProtocol.scala | 7 +++++ .../datasources/FileFormatWriter.scala | 22 +++++++++---- .../sql/test/DataFrameReaderWriterSuite.scala | 31 ++++++++++++++++++- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 2394cf361c33..7efa9416362a 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -121,6 +121,13 @@ abstract class FileCommitProtocol { def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = { fs.delete(path, recursive) } + + /** + * Called on the driver after a task commits. This can be used to access task commit messages + * before the job has finished. These same task commit messages will be passed to commitJob() + * if the entire job succeeds. + */ + def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 7957224ce48b..bda64d4b91bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -80,6 +80,9 @@ object FileFormatWriter extends Logging { """.stripMargin) } + /** The result of a successful write task. */ + private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String]) + /** * Basic work flow of this command is: * 1. 
Driver side setup, including output committer initialization and data source specific @@ -172,8 +175,9 @@ object FileFormatWriter extends Logging { global = false, child = queryExecution.executedPlan).execute() } - - val ret = sparkSession.sparkContext.runJob(rdd, + val ret = new Array[WriteTaskResult](rdd.partitions.length) + sparkSession.sparkContext.runJob( + rdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -182,10 +186,16 @@ object FileFormatWriter extends Logging { sparkAttemptNumber = taskContext.attemptNumber(), committer, iterator = iter) + }, + 0 until rdd.partitions.length, + (index, res: WriteTaskResult) => { + committer.onTaskCommit(res.commitMsg) + ret(index) = res }) - val commitMsgs = ret.map(_._1) - val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment) + val commitMsgs = ret.map(_.commitMsg) + val updatedPartitions = ret.flatMap(_.updatedPartitions) + .distinct.map(PartitioningUtils.parsePathFragment) committer.commitJob(job, commitMsgs) logInfo(s"Job ${job.getJobID} committed.") @@ -205,7 +215,7 @@ object FileFormatWriter extends Logging { sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, - iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = { + iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -238,7 +248,7 @@ object FileFormatWriter extends Logging { // Execute the task to write rows out and commit the task. val outputPartitions = writeTask.execute(iterator) writeTask.releaseResources() - (committer.commitTask(taskAttemptContext), outputPartitions) + WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions) })(catchBlock = { // If there is an error, release resource and then abort the task try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8287776f8f55..7c71e7280c6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,9 +18,12 @@ package org.apache.spark.sql.test import java.io.File +import java.util.concurrent.ConcurrentLinkedQueue import org.scalatest.BeforeAndAfter +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ @@ -41,7 +44,6 @@ object LastOptions { } } - /** Dummy provider. 
*/ class DefaultSource extends RelationProvider @@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema } } +object MessageCapturingCommitProtocol { + val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]() +} + +class MessageCapturingCommitProtocol(jobId: String, path: String) + extends HadoopMapReduceCommitProtocol(jobId, path) { + + // captures commit messages for testing + override def onTaskCommit(msg: TaskCommitMessage): Unit = { + MessageCapturingCommitProtocol.commitMessages.offer(msg) + } +} + + class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be Option(dir).map(spark.read.format("org.apache.spark.sql.test").load) } + test("write path implements onTaskCommit API correctly") { + withSQLConf( + "spark.sql.sources.commitProtocolClass" -> + classOf[MessageCapturingCommitProtocol].getCanonicalName) { + withTempDir { dir => + val path = dir.getCanonicalPath + MessageCapturingCommitProtocol.commitMessages.clear() + spark.range(10).repartition(10).write.mode("overwrite").parquet(path) + assert(MessageCapturingCommitProtocol.commitMessages.size() == 10) + } + } + } + test("read a data source that does not extend SchemaRelationProvider") { val dfReader = spark.read .option("from", "1") From 471de5db53ed77711523a3f016d6e9c530b651e5 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 29 Mar 2017 21:38:26 -0700 Subject: [PATCH 169/512] [MINOR][SPARKR] Add run command comment in examples ## What changes were proposed in this pull request? There are two examples in r folder missing the run commands. In this PR, I just add the missing comment, which is consistent with other examples. ## How was this patch tested? Manual test. Author: wm624@hotmail.com Closes #17474 from wangmiao1981/stat. --- examples/src/main/r/RSparkSQLExample.R | 3 +++ examples/src/main/r/dataframe.R | 3 +++ 2 files changed, 6 insertions(+) diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index e647f0e1e9f1..3734568d872d 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -15,6 +15,9 @@ # limitations under the License. # +# To run this example use +# ./bin/spark-submit examples/src/main/r/RSparkSQLExample.R + library(SparkR) # $example on:init_session$ diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 82b85f2f590f..311350497f87 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -15,6 +15,9 @@ # limitations under the License. # +# To run this example use +# ./bin/spark-submit examples/src/main/r/dataframe.R + library(SparkR) # Initialize SparkSession From edc87d76efea7b4d19d9d0c4ddba274a3ccb8752 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 30 Mar 2017 10:39:57 +0100 Subject: [PATCH 170/512] [SPARK-20107][DOC] Add spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version option to configuration.md ## What changes were proposed in this pull request? Add `spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version` option to `configuration.md`. Set `spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version=2` can speed up [HadoopMapReduceCommitProtocol.commitJob](https://github.com/apache/spark/blob/v2.1.0/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala#L121) for many output files. 
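As with other `spark.hadoop.*` keys, the option is forwarded to the underlying Hadoop `Configuration`, so it can be enabled per application, e.g. (an illustrative sketch, not part of this patch):

```
// Hypothetical setup: spark.hadoop.* entries are copied into the Hadoop Configuration,
// so the file output committer uses algorithm version 2 when committing the job.
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

val conf = new SparkConf()
  .set("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version", "2")
val spark = SparkSession.builder().config(conf).getOrCreate()
```

The same key can equally be passed on the command line via `--conf`.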
All cloudera's hadoop 2.6.0-cdh5.4.0 or higher versions(see: https://github.com/cloudera/hadoop-common/commit/1c1236182304d4075276c00c4592358f428bc433 and https://github.com/cloudera/hadoop-common/commit/16b2de27321db7ce2395c08baccfdec5562017f0) and apache's hadoop 2.7.0 or higher versions support this improvement. More see: 1. [MAPREDUCE-4815](https://issues.apache.org/jira/browse/MAPREDUCE-4815): Speed up FileOutputCommitter#commitJob for many output files. 2. [MAPREDUCE-6406](https://issues.apache.org/jira/browse/MAPREDUCE-6406): Update the default version for the property mapreduce.fileoutputcommitter.algorithm.version to 2. ## How was this patch tested? Manual test and exist tests. Author: Yuming Wang Closes #17442 from wangyum/SPARK-20107. --- docs/configuration.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 4729f1b0404c..a9753925407d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1137,6 +1137,15 @@ Apart from these, the following properties are also available, and may be useful mapping has high overhead for blocks close to or below the page size of the operating system. + + spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version + 1 + + The file output committer algorithm version, valid algorithm version number: 1 or 2. + Version 2 may have better performance, but version 1 may handle failures better in certain situations, + as per MAPREDUCE-4815. + + ### Networking From b454d4402e5ee7d1a7385d1fe3737581f84d2c72 Mon Sep 17 00:00:00 2001 From: Shubham Chopra Date: Thu, 30 Mar 2017 22:21:57 +0800 Subject: [PATCH 171/512] [SPARK-15354][CORE] Topology aware block replication strategies ## What changes were proposed in this pull request? Implementations of strategies for resilient block replication for different resource managers that replicate the 3-replica strategy used by HDFS, where the first replica is on an executor, the second replica within the same rack as the executor and a third replica on a different rack. The implementation involves providing two pluggable classes, one running in the driver that provides topology information for every host at cluster start and the second prioritizing a list of peer BlockManagerIds. The prioritization itself can be thought of an optimization problem to find a minimal set of peers that satisfy certain objectives and replicating to these peers first. The objectives can be used to express richer constraints over and above HDFS like 3-replica strategy. ## How was this patch tested? This patch was tested with unit tests for storage, along with new unit tests to verify prioritization behaviour. Author: Shubham Chopra Closes #13932 from shubhamchopra/PrioritizerStrategy. 
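To illustrate how the pluggable pieces fit together, the feature can be wired up the same way the new test suite below does: provide a `TopologyMapper` for rack lookup and select the topology-aware policy via configuration (a sketch with a placeholder rack function, not code from this patch):

```
import org.apache.spark.SparkConf
import org.apache.spark.storage.{BasicBlockReplicationPolicy, TopologyMapper}

// Placeholder mapper: a real implementation would resolve the host's rack from the
// cluster manager or a topology script rather than hashing the hostname.
class ExampleRackMapper(conf: SparkConf) extends TopologyMapper(conf) {
  override def getTopologyForHost(hostname: String): Option[String] =
    Some(s"/rack-${math.abs(hostname.hashCode) % 2}")
}

val conf = new SparkConf()
  .set("spark.storage.replication.policy", classOf[BasicBlockReplicationPolicy].getName)
  .set("spark.storage.replication.topologyMapper", classOf[ExampleRackMapper].getName)
```

With these settings the prioritizer tries to place one replica within the writer's rack and one outside it before filling the remaining slots randomly, mirroring the HDFS-style placement described above.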
--- .../apache/spark/storage/BlockManager.scala | 3 - .../storage/BlockReplicationPolicy.scala | 145 ++++++++++++++++-- .../BlockManagerReplicationSuite.scala | 33 +++- .../storage/BlockReplicationPolicySuite.scala | 73 +++++++-- 4 files changed, 222 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fcda9fa65303..46a078b2f9f9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -49,7 +49,6 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer - /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( val data: Iterator[Any], @@ -1258,7 +1257,6 @@ private[spark] class BlockManager( replication = 1) val numPeersToReplicateTo = level.replication - 1 - val startTime = System.nanoTime var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas @@ -1313,7 +1311,6 @@ private[spark] class BlockManager( numPeersToReplicateTo - peersReplicatedTo.size) } } - logDebug(s"Replicating $blockId of ${data.size} bytes to " + s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms") if (peersReplicatedTo.size < numPeersToReplicateTo) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bb8a684b4c7a..353eac60df17 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -53,6 +53,46 @@ trait BlockReplicationPolicy { numReplicas: Int): List[BlockManagerId] } +object BlockReplicationUtils { + // scalastyle:off line.size.limit + /** + * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while + * minimizing space usage. Please see + * here. + * + * @param n total number of indices + * @param m number of samples needed + * @param r random number generator + * @return list of m random unique indices + */ + // scalastyle:on line.size.limit + private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { + val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => + val t = r.nextInt(i) + 1 + if (set.contains(t)) set + i else set + t + } + indices.map(_ - 1).toList + } + + /** + * Get a random sample of size m from the elems + * + * @param elems + * @param m number of samples needed + * @param r random number generator + * @tparam T + * @return a random list of size m. If there are fewer than m elements in elems, we just + * randomly shuffle elems + */ + def getRandomSample[T](elems: Seq[T], m: Int, r: Random): List[T] = { + if (elems.size > m) { + getSampleIds(elems.size, m, r).map(elems(_)) + } else { + r.shuffle(elems).toList + } + } +} + @DeveloperApi class RandomBlockReplicationPolicy extends BlockReplicationPolicy @@ -67,6 +107,7 @@ class RandomBlockReplicationPolicy * @param peersReplicatedTo Set of peers already replicated to * @param blockId BlockId of the block being replicated. This can be used as a source of * randomness if needed. + * @param numReplicas Number of peers we need to replicate to * @return A prioritized list of peers. 
Lower the index of a peer, higher its priority */ override def prioritize( @@ -78,7 +119,7 @@ class RandomBlockReplicationPolicy val random = new Random(blockId.hashCode) logDebug(s"Input peers : ${peers.mkString(", ")}") val prioritizedPeers = if (peers.size > numReplicas) { - getSampleIds(peers.size, numReplicas, random).map(peers(_)) + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) } else { if (peers.size < numReplicas) { logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.") @@ -88,26 +129,96 @@ class RandomBlockReplicationPolicy logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}") prioritizedPeers } +} + +@DeveloperApi +class BasicBlockReplicationPolicy + extends BlockReplicationPolicy + with Logging { - // scalastyle:off line.size.limit /** - * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage. Please see - * here. + * Method to prioritize a bunch of candidate peers of a block manager. This implementation + * replicates the behavior of block replication in HDFS. For a given number of replicas needed, + * we choose a peer within the rack, one outside and remaining blockmanagers are chosen at + * random, in that order till we meet the number of replicas needed. + * This works best with a total replication factor of 3, like HDFS. * - * @param n total number of indices - * @param m number of samples needed - * @param r random number generator - * @return list of m random unique indices + * @param blockManagerId Id of the current BlockManager for self identification + * @param peers A list of peers of a BlockManager + * @param peersReplicatedTo Set of peers already replicated to + * @param blockId BlockId of the block being replicated. This can be used as a source of + * randomness if needed. + * @param numReplicas Number of peers we need to replicate to + * @return A prioritized list of peers. Lower the index of a peer, higher its priority */ - // scalastyle:on line.size.limit - private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { - val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => - val t = r.nextInt(i) + 1 - if (set.contains(t)) set + i else set + t + override def prioritize( + blockManagerId: BlockManagerId, + peers: Seq[BlockManagerId], + peersReplicatedTo: mutable.HashSet[BlockManagerId], + blockId: BlockId, + numReplicas: Int): List[BlockManagerId] = { + + logDebug(s"Input peers : $peers") + logDebug(s"BlockManagerId : $blockManagerId") + + val random = new Random(blockId.hashCode) + + // if block doesn't have topology info, we can't do much, so we randomly shuffle + // if there is, we see what's needed from peersReplicatedTo and based on numReplicas, + // we choose whats needed + if (blockManagerId.topologyInfo.isEmpty || numReplicas == 0) { + // no topology info for the block. 
The best we can do is randomly choose peers + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) + } else { + // we have topology information, we see what is left to be done from peersReplicatedTo + val doneWithinRack = peersReplicatedTo.exists(_.topologyInfo == blockManagerId.topologyInfo) + val doneOutsideRack = peersReplicatedTo.exists { p => + p.topologyInfo.isDefined && p.topologyInfo != blockManagerId.topologyInfo + } + + if (doneOutsideRack && doneWithinRack) { + // we are done, we just return a random sample + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) + } else { + // we separate peers within and outside rack + val (inRackPeers, outOfRackPeers) = peers + .filter(_.host != blockManagerId.host) + .partition(_.topologyInfo == blockManagerId.topologyInfo) + + val peerWithinRack = if (doneWithinRack) { + // we are done with in-rack replication, so don't need anymore peers + Seq.empty + } else { + if (inRackPeers.isEmpty) { + Seq.empty + } else { + Seq(inRackPeers(random.nextInt(inRackPeers.size))) + } + } + + val peerOutsideRack = if (doneOutsideRack || numReplicas - peerWithinRack.size <= 0) { + Seq.empty + } else { + if (outOfRackPeers.isEmpty) { + Seq.empty + } else { + Seq(outOfRackPeers(random.nextInt(outOfRackPeers.size))) + } + } + + val priorityPeers = peerWithinRack ++ peerOutsideRack + val numRemainingPeers = numReplicas - priorityPeers.size + val remainingPeers = if (numRemainingPeers > 0) { + val rPeers = peers.filter(p => !priorityPeers.contains(p)) + BlockReplicationUtils.getRandomSample(rPeers, numRemainingPeers, random) + } else { + Seq.empty + } + + (priorityPeers ++ remainingPeers).toList + } + } - // we shuffle the result to ensure a random arrangement within the sample - // to avoid any bias from set implementations - r.shuffle(indices.map(_ - 1).toList) } + } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d5715f8469f7..13020acdd3db 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.internal.Logging import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -36,6 +37,7 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.util.Utils trait BlockManagerReplicationBehavior extends SparkFunSuite with Matchers @@ -43,6 +45,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite with LocalSparkContext { val conf: SparkConf + protected var rpcEnv: RpcEnv = null protected var master: BlockManagerMaster = null protected lazy val securityMgr = new SecurityManager(conf) @@ -55,7 +58,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - protected lazy val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. 
@@ -471,7 +473,7 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav conf.set("spark.storage.replication.proactive", "true") conf.set("spark.storage.exceptionOnPinLeak", "true") - (2 to 5).foreach{ i => + (2 to 5).foreach { i => test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { testProactiveReplication(i) } @@ -524,3 +526,30 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav } } } + +class DummyTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + // number of racks to test with + val numRacks = 3 + + /** + * Gets the topology information given the host name + * + * @param hostname Hostname + * @return random topology + */ + override def getTopologyForHost(hostname: String): Option[String] = { + Some(s"/Rack-${Utils.random.nextInt(numRacks)}") + } +} + +class BlockManagerBasicStrategyReplicationSuite extends BlockManagerReplicationBehavior { + val conf: SparkConf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set( + "spark.storage.replication.policy", + classOf[BasicBlockReplicationPolicy].getName) + conf.set( + "spark.storage.replication.topologyMapper", + classOf[DummyTopologyMapper].getName) +} + diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index 800c3899f1a7..ecad0f5352e5 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,34 +18,34 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{LocalSparkContext, SparkFunSuite} -class BlockReplicationPolicySuite extends SparkFunSuite +class RandomBlockReplicationPolicyBehavior extends SparkFunSuite with Matchers with BeforeAndAfter with LocalSparkContext { // Implicitly convert strings to BlockIds for test clarity. 
- private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + protected implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + val replicationPolicy: BlockReplicationPolicy = new RandomBlockReplicationPolicy + + val blockId = "test-block" /** * Test if we get the required number of peers when using random sampling from - * RandomBlockReplicationPolicy + * BlockReplicationPolicy */ - test(s"block replication - random block replication policy") { + test("block replication - random block replication policy") { val numBlockManagers = 10 val storeSize = 1000 - val blockManagers = (1 to numBlockManagers).map { i => - BlockManagerId(s"store-$i", "localhost", 1000 + i, None) - } + val blockManagers = generateBlockManagerIds(numBlockManagers, Seq("/Rack-1")) val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None) - val replicationPolicy = new RandomBlockReplicationPolicy - val blockId = "test-block" - (1 to 10).foreach {numReplicas => + (1 to 10).foreach { numReplicas => logDebug(s"Num replicas : $numReplicas") val randomPeers = replicationPolicy.prioritize( candidateBlockManager, @@ -68,7 +68,60 @@ class BlockReplicationPolicySuite extends SparkFunSuite logDebug(s"Random peers : ${secondPass.mkString(", ")}") assert(secondPass.toSet.size === numReplicas) } + } + + protected def generateBlockManagerIds(count: Int, racks: Seq[String]): Seq[BlockManagerId] = { + (1 to count).map{i => + BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(racks(Random.nextInt(racks.size)))) + } + } +} + +class TopologyAwareBlockReplicationPolicyBehavior extends RandomBlockReplicationPolicyBehavior { + override val replicationPolicy = new BasicBlockReplicationPolicy + + test("All peers in the same rack") { + val racks = Seq("/default-rack") + val numBlockManager = 10 + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(numBlockManager, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 10001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + assert(prioritizedPeers.toSet.size == numReplicas) + assert(prioritizedPeers.forall(p => p.host != blockManager.host)) + } } + test("Peers in 2 racks") { + val racks = Seq("/Rack-1", "/Rack-2") + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(10, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 9001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + + assert(prioritizedPeers.toSet.size == numReplicas) + val priorityPeers = prioritizedPeers.take(2) + assert(priorityPeers.forall(p => p.host != blockManager.host)) + if(numReplicas > 1) { + // both these conditions should be satisfied when numReplicas > 1 + assert(priorityPeers.exists(p => p.topologyInfo == blockManager.topologyInfo)) + assert(priorityPeers.exists(p => p.topologyInfo != blockManager.topologyInfo)) + } + } + } } From 0197262a358fd174a188f8246ae777e53157610e Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 30 Mar 2017 16:07:27 +0100 Subject: [PATCH 172/512] [DOCS] Docs-only improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …adoc ## What changes were proposed in this pull request? Use recommended values for row boundaries in Window's scaladoc, i.e. 
`Window.unboundedPreceding`, `Window.unboundedFollowing`, and `Window.currentRow` (that were introduced in 2.1.0). ## How was this patch tested? Local build Author: Jacek Laskowski Closes #17417 from jaceklaskowski/window-expression-scaladoc. --- .../apache/spark/memory/MemoryConsumer.java | 2 -- .../sort/BypassMergeSortShuffleWriter.java | 5 ++-- .../spark/ExecutorAllocationClient.scala | 5 ++-- .../org/apache/spark/scheduler/Task.scala | 2 +- .../apache/spark/serializer/Serializer.scala | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 3 +-- .../shuffle/IndexShuffleBlockResolver.scala | 4 ++-- .../shuffle/sort/SortShuffleManager.scala | 4 ++-- .../org/apache/spark/util/AccumulatorV2.scala | 2 +- .../spark/examples/ml/DataFrameExample.scala | 2 +- .../apache/spark/ml/stat/Correlation.scala | 2 +- .../sql/catalyst/analysis/ResolveHints.scala | 2 +- .../catalyst/encoders/ExpressionEncoder.scala | 6 ++--- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/windowExpressions.scala | 2 +- .../sql/catalyst/optimizer/objects.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 6 ++--- .../spark/sql/catalyst/plans/QueryPlan.scala | 5 ++-- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../parser/ExpressionParserSuite.scala | 3 ++- .../scala/org/apache/spark/sql/Column.scala | 18 +++++++-------- .../org/apache/spark/sql/DatasetHolder.scala | 3 ++- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../sql/execution/command/databases.scala | 2 +- .../sql/execution/streaming/Source.scala | 2 +- .../apache/spark/sql/expressions/Window.scala | 23 ++++++++++--------- .../spark/sql/expressions/WindowSpec.scala | 20 ++++++++-------- .../org/apache/spark/sql/functions.scala | 2 +- .../sql/hive/HiveSessionStateBuilder.scala | 2 +- .../scheduler/InputInfoTracker.scala | 2 +- 30 files changed, 68 insertions(+), 71 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index fc1f3a80239b..48cf4b9455e4 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -60,8 +60,6 @@ protected long getUsed() { /** * Force spill during building. - * - * For testing. */ public void spill() throws IOException { spill(Long.MAX_VALUE, this); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 4a15559e55cb..323a5d3c5283 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -52,8 +52,7 @@ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path * writes incoming records to separate files, one file per reduce partition, then concatenates these * per-partition files to form a single output file, regions of which are served to reducers. - * Records are not buffered in memory. This is essentially identical to - * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * Records are not buffered in memory. It writes output in a format * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. *

 * This write path is inefficient for shuffles with large numbers of reduce partitions because it
@@ -61,7 +60,7 @@
  * {@link SortShuffleManager} only selects this write path when
  * <ul>
  *    <li>no Ordering is specified,</li>
- *    <li>no Aggregator is specific, and</li>
+ *    <li>no Aggregator is specified, and</li>
  *    <li>the number of partitions is less than
  *      <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
  * </ul>
      diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index e4b9f8111efc..9112d93a86b2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -71,13 +71,12 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill every executor on the specified host. - * Results in a call to killExecutors for each executor on the host, with the replace - * and force arguments set to true. + * * @return whether the request is acknowledged by the cluster manager. */ def killExecutorsOnHost(host: String): Boolean - /** + /** * Request that the cluster manager kill the specified executor. * @return whether the request is acknowledged by the cluster manager. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 46ef23f316a6..7fd2918960cd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -149,7 +149,7 @@ private[spark] abstract class Task[T]( def preferredLocations: Seq[TaskLocation] = Nil - // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskSetManager. var epoch: Long = -1 // Task context, to be initialized in run(). diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 008b0387899f..01bbda0b5e6b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -77,7 +77,7 @@ abstract class Serializer { * position = 0 * serOut.write(obj1) * serOut.flush() - * position = # of bytes writen to stream so far + * position = # of bytes written to stream so far * obj1Bytes = output[0:position-1] * serOut.write(obj2) * serOut.flush() diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 8b2e26cdd94f..ba3e0e395e95 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -95,8 +95,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, - // the ExternalSorter won't spill to disk. + // Create an ExternalSorter to sort the data. val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 91858f0912b6..15540485170d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -61,7 +61,7 @@ private[spark] class IndexShuffleBlockResolver( /** * Remove data file and index file that contain the output data from one map. 
- * */ + */ def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { var file = getDataFile(shuffleId, mapId) if (file.exists()) { @@ -132,7 +132,7 @@ private[spark] class IndexShuffleBlockResolver( * replace them with new ones. * * Note: the `lengths` will be updated to match the existing index file if use the existing ones. - * */ + */ def writeIndexFileAndCommit( shuffleId: Int, mapId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 5e977a16febe..bfb4dc698e32 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -82,13 +82,13 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + * Obtains a [[ShuffleHandle]] to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't // need map-side aggregation, then write numPartitions files directly and just concatenate // them at the end. This avoids doing serialization and deserialization twice to merge diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 00e0cf257cd4..7479de55140e 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -279,7 +279,7 @@ private[spark] object AccumulatorContext { /** - * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers. + * An [[AccumulatorV2 accumulator]] for computing sum, count, and average of 64-bit integers. * * @since 2.0.0 */ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index e07c9a4717c3..0658bddf1696 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.util.Utils /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with + * An example of how to use [[DataFrame]] for ML. Run with * {{{ * ./bin/run-example ml.DataFrameExample [options] * }}} diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala index a7243ccbf28c..d3c84b77d26a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * API for correlation functions in MLlib, compatible with Dataframes and Datasets. 
* - * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]] + * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset#stat]] * to spark.ml's Vector types. */ @Since("2.2.0") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 70438eb5912b..920033a9a848 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin /** * Collection of rules related to hints. The only hint currently available is broadcast join hint. * - * Note that this is separatedly into two rules because in the future we might introduce new hint + * Note that this is separately into two rules because in the future we might introduce new hint * rules that have different ordering requirements from broadcast. */ object ResolveHints { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 93fc565a5341..ec003cdc17b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -229,9 +229,9 @@ case class ExpressionEncoder[T]( // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here - // (intermediate value is not an attribute). We assume that all serializer expressions use a same - // `BoundReference` to refer to the object, and throw exception if they don't. - assert(serializer.forall(_.references.isEmpty), "serializer cannot reference to any attributes.") + // (intermediate value is not an attribute). We assume that all serializer expressions use the + // same `BoundReference` to refer to the object, and throw exception if they don't. + assert(serializer.forall(_.references.isEmpty), "serializer cannot reference any attributes.") assert(serializer.flatMap { ser => val boundRefs = ser.collect { case b: BoundReference => b } assert(boundRefs.nonEmpty, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b93a5d0b7a0e..1db26d9c415a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -491,7 +491,7 @@ abstract class BinaryExpression extends Expression { * A [[BinaryExpression]] that is an operator, with two properties: * * 1. The string representation is "x symbol y", rather than "funcName(x, y)". - * 2. Two inputs are expected to the be same type. If the two inputs have different types, + * 2. Two inputs are expected to be of the same type. If the two inputs have different types, * the analyzer will find the tightest common type and do the proper type casting. 
*/ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 07d294b10854..b2a3888ff7b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -695,7 +695,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { * * This documentation has been based upon similar documentation for the Hive and Presto projects. * - * @param children to base the rank on; a change in the value of one the children will trigger a + * @param children to base the rank on; a change in the value of one of the children will trigger a * change in rank. This is an internal parameter and will be assigned by the * Analyser. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 174d546e2280..257dbfac8c3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -65,7 +65,7 @@ object EliminateSerialization extends Rule[LogicalPlan] { /** * Combines two adjacent [[TypedFilter]]s, which operate on same type object in condition, into one, - * mering the filter functions into one conjunctive function. + * merging the filter functions into one conjunctive function. */ object CombineTypedFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cd238e05d410..162051a8c0e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -492,7 +492,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add an [[Aggregate]] to a logical plan. + * Add an [[Aggregate]] or [[GroupingSets]] to a logical plan. */ private def withAggregation( ctx: AggregationContext, @@ -519,7 +519,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add a Hint to a logical plan. + * Add a [[Hint]] to a logical plan. */ private def withHints( ctx: HintContext, @@ -545,7 +545,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a single relation referenced in a FROM claused. This method is used when a part of the + * Create a single relation referenced in a FROM clause. 
This method is used when a part of the * join condition is nested, for example: * {{{ * select * from t1 join (t2 cross join t3) on col1 = col2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 9fd95a4b368c..2d8ec2053a4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -230,14 +230,15 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def producedAttributes: AttributeSet = AttributeSet.empty /** - * Attributes that are referenced by expressions but not provided by this nodes children. + * Attributes that are referenced by expressions but not provided by this node's children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. */ def missingInput: AttributeSet = references -- inputSet -- producedAttributes /** - * Runs [[transform]] with `rule` on all expressions present in this query operator. + * Runs [[transformExpressionsDown]] with `rule` on all expressions present + * in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e22b429aec68..f71a976bd7a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -32,7 +32,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { private var _analyzed: Boolean = false /** - * Marks this plan as already analyzed. This should only be called by CheckAnalysis. + * Marks this plan as already analyzed. This should only be called by [[CheckAnalysis]]. */ private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index c2e62e739776..d1c6b50536cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval /** - * Test basic expression parsing. If a type of expression is supported it should be tested here. + * Test basic expression parsing. + * If the type of an expression is supported it should be tested here. * * Please note that some of the expressions test don't have to be sound expressions, only their * structure needs to be valid. 
Unsound expressions should be caught by the Analyzer or diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ae0703513cf4..43de2de7e709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -84,8 +84,8 @@ class TypedColumn[-T, U]( } /** - * Gives the TypedColumn a name (alias). - * If the current TypedColumn has metadata associated with it, this metadata will be propagated + * Gives the [[TypedColumn]] a name (alias). + * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated * to the new column. * * @group expr_ops @@ -99,16 +99,14 @@ class TypedColumn[-T, U]( /** * A column that will be computed based on the data in a `DataFrame`. * - * A new column is constructed based on the input columns present in a dataframe: + * A new column can be constructed based on the input columns present in a DataFrame: * * {{{ - * df("columnName") // On a specific DataFrame. + * df("columnName") // On a specific `df` DataFrame. * col("columnName") // A generic column no yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. - * expr("a + 1") // A column that is constructed from a parsed SQL Expression. - * lit("abc") // A column that produces a literal (constant) value. * }}} * * [[Column]] objects can be composed to form complex expressions: @@ -118,7 +116,7 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * @note The internal Catalyst expression can be accessed via "expr", but this method is for + * @note The internal Catalyst expression can be accessed via [[expr]], but this method is for * debugging purposes only and can change in any future Spark releases. * * @groupname java_expr_ops Java-specific expression operators @@ -1100,7 +1098,7 @@ class Column(val expr: Expression) extends Logging { def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } /** - * Prints the expression to the console for debugging purpose. + * Prints the expression to the console for debugging purposes. * * @group df_ops * @since 1.3.0 @@ -1154,8 +1152,8 @@ class Column(val expr: Expression) extends Logging { * {{{ * val w = Window.partitionBy("name").orderBy("id") * df.select( - * sum("price").over(w.rangeBetween(Long.MinValue, 2)), - * avg("price").over(w.rowsBetween(0, 4)) + * sum("price").over(w.rangeBetween(Window.unboundedPreceding, 2)), + * avg("price").over(w.rowsBetween(Window.currentRow, 4)) * ) * }}} * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 18bccee98f61..582d4a3670b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability * * To use this, import implicit conversions in SQL: * {{{ - * import sqlContext.implicits._ + * val spark: SparkSession = ... 
+ * import spark.implicits._ * }}} * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index a97297892b5e..b60499253c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -60,7 +60,7 @@ import org.apache.spark.util.Utils * The builder can also be used to create a new session: * * {{{ - * SparkSession.builder() + * SparkSession.builder * .master("local") * .appName("Word Count") * .config("spark.some.config.option", "some-value") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index e5a6a5f60b8a..470c736da98b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.StringType /** * A command for users to list the databases/schemas. - * If a databasePattern is supplied then the databases that only matches the + * If a databasePattern is supplied then the databases that only match the * pattern would be listed. * The syntax of using this command in SQL is: * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 75ffe90f2bb7..311942f6dbd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark * will regularly query each [[Source]] to see if any more data is available. */ -trait Source { +trait Source { /** Returns the schema of the data from this source */ def schema: StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index f3cf3052ea3e..00053485e614 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -113,7 +113,7 @@ object Window { * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative positions from the current row. For example, "0" means + * Both `start` and `end` are positions relative to the current row. For example, "0" means * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. * @@ -131,9 +131,9 @@ object Window { * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -150,7 +150,7 @@ object Window { * @param start boundary start, inclusive. 
The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. @@ -162,7 +162,7 @@ object Window { * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * Both `start` and `end` are relative to the current row. For example, "0" means "current row", * while "-1" means one off before the current row, and "5" means the five off after the * current row. * @@ -183,9 +183,9 @@ object Window { * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -202,7 +202,7 @@ object Window { * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rangeBetween. @@ -221,7 +221,8 @@ object Window { * * {{{ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * Window.partitionBy("country").orderBy("date") + * .rowsBetween(Window.unboundedPreceding, Window.currentRow) * * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index de7d7a177275..6279d48c94de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -86,7 +86,7 @@ class WindowSpec private[sql]( * after the current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `[Window.currentRow` to specify special boundary values, rather than using integral + * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * * A row based boundary is based on the position of the row within the partition. 
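For readers following the recommendation repeated in these scaladoc hunks, a small end-to-end usage sketch with the named boundary constants, suitable for pasting into spark-shell (the application name, local master, and column names are placeholders for illustration, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

val spark = SparkSession.builder().master("local[*]").appName("window-sketch").getOrCreate()
import spark.implicits._

val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")).toDF("id", "category")

// Running sum per category, using the named constants rather than raw Long values.
val byCategoryOrderedById = Window
  .partitionBy($"category")
  .orderBy($"id")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("running_sum", sum($"id").over(byCategoryOrderedById)).show()
```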
@@ -99,9 +99,9 @@ class WindowSpec private[sql]( * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -118,7 +118,7 @@ class WindowSpec private[sql]( * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rowsBetween. @@ -134,7 +134,7 @@ class WindowSpec private[sql]( * current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `[Window.currentRow` to specify special boundary values, rather than using integral + * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * * A range based boundary is based on the actual value of the ORDER BY @@ -150,9 +150,9 @@ class WindowSpec private[sql]( * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -169,7 +169,7 @@ class WindowSpec private[sql]( * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rangeBetween. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0f9203065ef0..f07e04368389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2968,7 +2968,7 @@ object functions { * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. accepts the same options and the + * @param options options to control how the json is parsed. Accepts the same options as the * json data source. 
* * @group collection_funcs diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 8048c2ba2c2e..2f3dfa05e9ef 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} /** - * Builder that produces a Hive aware [[SessionState]]. + * Builder that produces a Hive-aware `SessionState`. */ @Experimental @InterfaceStability.Unstable diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 8e1a09061843..639ac6de4f5d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -66,7 +66,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch " + s"$batchTime is already added into InputInfoTracker, this is an illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo)) From 258bff2c3f54490ddca898e276029db9adf575d9 Mon Sep 17 00:00:00 2001 From: samelamin Date: Thu, 30 Mar 2017 16:08:26 +0100 Subject: [PATCH 173/512] [SPARK-19999] Workaround JDK-8165231 to identify PPC64 architectures as supporting unaligned access java.nio.Bits.unaligned() does not return true for the ppc64le arch. see https://bugs.openjdk.java.net/browse/JDK-8165231 ## What changes were proposed in this pull request? check architecture ## How was this patch tested? unit test Author: samelamin Author: samelamin Closes #17472 from samelamin/SPARK-19999. --- .../org/apache/spark/unsafe/Platform.java | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index f13c24ae5e01..1321b8318115 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -46,18 +46,22 @@ public final class Platform { private static final boolean unaligned; static { boolean _unaligned; - // use reflection to access unaligned field - try { - Class bitsClass = - Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); - Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); - unalignedMethod.setAccessible(true); - _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); - } catch (Throwable t) { - // We at least know x86 and x64 support unaligned access. 
- String arch = System.getProperty("os.arch", ""); - //noinspection DynamicRegexReplaceableByCompiledPattern - _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + String arch = System.getProperty("os.arch", ""); + if (arch.equals("ppc64le") || arch.equals("ppc64")) { + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but ppc64 and ppc64le support it + _unaligned = true; + } else { + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + unalignedMethod.setAccessible(true); + _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); + } catch (Throwable t) { + // We at least know x86 and x64 support unaligned access. + //noinspection DynamicRegexReplaceableByCompiledPattern + _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + } } unaligned = _unaligned; } From e9d268f63e7308486739aa56ece02815bfb432d6 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 30 Mar 2017 16:11:03 +0100 Subject: [PATCH 174/512] [SPARK-20096][SPARK SUBMIT][MINOR] Expose the right queue name not null if set by --conf or configure file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? while submit apps with -v or --verbose, we can print the right queue name, but if we set a queue name with `spark.yarn.queue` by --conf or in the spark-default.conf, we just got `null` for the queue in Parsed arguments. ``` bin/spark-shell -v --conf spark.yarn.queue=thequeue Using properties file: /home/hadoop/spark-2.1.0-bin-apache-hdp2.7.3/conf/spark-defaults.conf .... Adding default property: spark.yarn.queue=default Parsed arguments: master yarn deployMode client ... queue null .... verbose true Spark properties used, including those specified through --conf and those from the properties file /home/hadoop/spark-2.1.0-bin-apache-hdp2.7.3/conf/spark-defaults.conf: spark.yarn.queue -> thequeue .... ``` ## How was this patch tested? ut and local verify Author: Kent Yao Closes #17430 from yaooqinn/SPARK-20096. 
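The one-line change in the diff that follows is the familiar "explicit flag wins, otherwise fall back to the Spark property" precedence. A standalone sketch of that resolution order (simplified for illustration; this is not the actual SparkSubmitArguments code):

```scala
object QueueResolutionSketch {
  // An explicit --queue value takes precedence; otherwise fall back to spark.yarn.queue.
  def resolveQueue(cliQueue: String, sparkProperties: Map[String, String]): String =
    Option(cliQueue).orElse(sparkProperties.get("spark.yarn.queue")).orNull

  def main(args: Array[String]): Unit = {
    // --queue was not passed, but the property came from --conf or spark-defaults.conf.
    assert(resolveQueue(null, Map("spark.yarn.queue" -> "thequeue")) == "thequeue")
    // An explicit --queue still wins over the property.
    assert(resolveQueue("urgent", Map("spark.yarn.queue" -> "thequeue")) == "urgent")
  }
}
```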
--- .../apache/spark/deploy/SparkSubmitArguments.scala | 1 + .../org/apache/spark/deploy/SparkSubmitSuite.scala | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0614d80b60e1..0144fd1056ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -190,6 +190,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a591b98bca48..7c2ec01a03d0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -148,6 +148,17 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("print the right queue name") { + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "--conf", "spark.yarn.queue=thequeue", + "userjar.jar") + val appArgs = new SparkSubmitArguments(clArgs) + appArgs.queue should be ("thequeue") + appArgs.toString should include ("thequeue") + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", From 669a11b61bc217a13217f1ef48d781329c45575e Mon Sep 17 00:00:00 2001 From: "Seigneurin, Alexis (CONT)" Date: Thu, 30 Mar 2017 16:12:17 +0100 Subject: [PATCH 175/512] [DOCS][MINOR] Fixed a few typos in the Structured Streaming documentation Fixed a few typos. There is one more I'm not sure of: ``` Append mode uses watermark to drop old aggregation state. But the output of a windowed aggregation is delayed the late threshold specified in `withWatermark()` as by the modes semantics, rows can be added to the Result Table only once after they are ``` Not sure how to change `is delayed the late threshold`. Author: Seigneurin, Alexis (CONT) Closes #17443 from aseigneurin/typos. --- docs/structured-streaming-programming-guide.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index ff07ad11943b..b5cf9f164498 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -717,11 +717,11 @@ However, to run this query for days, it's necessary for the system to bound the intermediate in-memory state it accumulates. This means the system needs to know when an old aggregate can be dropped from the in-memory state because the application is not going to receive late data for that aggregate any more. To enable this, in Spark 2.1, we have introduced -**watermarking**, which let's the engine automatically track the current event time in the data and +**watermarking**, which lets the engine automatically track the current event time in the data and attempt to clean up old state accordingly. 
You can define the watermark of a query by -specifying the event time column and the threshold on how late the data is expected be in terms of +specifying the event time column and the threshold on how late the data is expected to be in terms of event time. For a specific window starting at time `T`, the engine will maintain state and allow late -data to be update the state until `(max event time seen by the engine - late threshold > T)`. +data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, but data later than the threshold will be dropped. Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below. @@ -792,7 +792,7 @@ This watermark lets the engine maintain intermediate state for additional 10 min data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in windows `12:05 - 12:15` and `12:10 - 12:20`. Since, it is still ahead of the watermark `12:04` in the trigger, the engine still maintains the intermediate counts as state and correctly updates the -counts of the related windows. However, when the watermark is updated to 12:11, the intermediate +counts of the related windows. However, when the watermark is updated to `12:11`, the intermediate state for window `(12:00 - 12:10)` is cleared, and all subsequent data (e.g. `(12:04, donkey)`) is considered "too late" and therefore ignored. Note that after every trigger, the updated counts (i.e. purple rows) are written to sink as the trigger output, as dictated by @@ -825,7 +825,7 @@ section for detailed explanation of the semantics of each output mode. same column as the timestamp column used in the aggregate. For example, `df.withWatermark("time", "1 min").groupBy("time2").count()` is invalid in Append output mode, as watermark is defined on a different column -as the aggregation column. +from the aggregation column. - `withWatermark` must be called before the aggregation for the watermark details to be used. For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append @@ -909,7 +909,7 @@ track of all the data received in the stream. This is therefore fundamentally ha efficiently. ## Starting Streaming Queries -Once you have defined the final result DataFrame/Dataset, all that is left is for you start the streaming computation. To do that, you have to use the `DataStreamWriter` +Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the `DataStreamWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. @@ -1396,15 +1396,15 @@ You can directly get the current status and metrics of an active query using `lastProgress()` returns a `StreamingQueryProgress` object in [Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryProgress) and [Java](api/java/org/apache/spark/sql/streaming/StreamingQueryProgress.html) -and an dictionary with the same fields in Python. It has all the information about +and a dictionary with the same fields in Python. 
It has all the information about the progress made in the last trigger of the stream - what data was processed, what were the processing rates, latencies, etc. There is also `streamingQuery.recentProgress` which returns an array of last few progresses. -In addition, `streamingQuery.status()` returns `StreamingQueryStatus` object +In addition, `streamingQuery.status()` returns a `StreamingQueryStatus` object in [Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryStatus) and [Java](api/java/org/apache/spark/sql/streaming/StreamingQueryStatus.html) -and an dictionary with the same fields in Python. It gives information about +and a dictionary with the same fields in Python. It gives information about what the query is immediately doing - is a trigger active, is data being processed, etc. Here are a few examples. From 5e00a5de14ae2d80471c6f38c30cc6fe63e05163 Mon Sep 17 00:00:00 2001 From: Denis Bolshakov Date: Thu, 30 Mar 2017 16:15:40 +0100 Subject: [PATCH 176/512] [SPARK-20127][CORE] few warning have been fixed which Intellij IDEA reported Intellij IDEA ## What changes were proposed in this pull request? Few changes related to Intellij IDEA inspection. ## How was this patch tested? Changes were tested by existing unit tests Author: Denis Bolshakov Closes #17458 from dbolshak/SPARK-20127. --- .../org/apache/spark/memory/TaskMemoryManager.java | 6 +----- .../org/apache/spark/status/api/v1/TaskSorting.java | 5 ++--- .../scala/org/apache/spark/io/CompressionCodec.scala | 3 +-- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 2 +- .../apache/spark/ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 3 +-- .../scala/org/apache/spark/ui/exec/ExecutorsTab.scala | 4 ++-- .../scala/org/apache/spark/ui/jobs/AllStagesPage.scala | 4 ++-- .../scala/org/apache/spark/ui/jobs/ExecutorTable.scala | 4 ++-- .../org/apache/spark/ui/jobs/JobProgressListener.scala | 4 ++-- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 10 +++++----- .../scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- .../org/apache/spark/ui/storage/StoragePage.scala | 2 +- 13 files changed, 22 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 39fb3b249d73..aa0b37323132 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -155,11 +155,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { long key = c.getUsed(); - List list = sortedConsumers.get(key); - if (list == null) { - list = new ArrayList<>(1); - sortedConsumers.put(key, list); - } + List list = sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); list.add(c); } } diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java index 9307eb93a5b2..b38639e85481 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -19,6 +19,7 @@ import org.apache.spark.util.EnumUtil; +import java.util.Collections; import java.util.HashSet; import java.util.Set; @@ -30,9 +31,7 @@ public enum TaskSorting { private final Set alternateNames; TaskSorting(String... 
names) { alternateNames = new HashSet<>(); - for (String n: names) { - alternateNames.add(n); - } + Collections.addAll(alternateNames, names); } public static TaskSorting fromString(String str) { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 2e991ce394c4..c216fe477fd1 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -71,8 +71,7 @@ private[spark] object CompressionCodec { val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { - case e: ClassNotFoundException => None - case e: IllegalArgumentException => None + case _: ClassNotFoundException | _: IllegalArgumentException => None } codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " + s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index a9480cc220c8..8b75f5d8fe1a 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -124,7 +124,7 @@ private[spark] abstract class WebUI( /** Bind to the HTTP server behind this web interface. */ def bind(): Unit = { - assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!") + assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") serverInfo = Some(startJettyServer(host, port, sslOptions, handlers, conf, name)) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index c6a07445f2a3..dbcc6402bc30 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -49,7 +49,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage }.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { - case Some(blockedByThreadId) => + case Some(_) =>
      Blocked by Thread {thread.blockedByThreadId} {thread.blockedByLock} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 2d1691e55c42..d849ce76a9e3 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -48,7 +48,6 @@ private[ui] class ExecutorsPage( parent: ExecutorsTab, threadDumpEnabled: Boolean) extends WebUIPage("") { - private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val content = @@ -59,7 +58,7 @@ private[ui] class ExecutorsPage( ++ } -
      ; + UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 8ae712f8ed32..03851293eb2f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -64,7 +64,7 @@ private[ui] case class ExecutorTaskSummary( @DeveloperApi class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { - var executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() + val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() var executorEvents = new ListBuffer[SparkListenerEvent]() private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000) @@ -137,7 +137,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar // could have failed half-way through. The correct fix would be to keep track of the // metrics added by each attempt, but this is much more complicated. return - case e: ExceptionFailure => + case _: ExceptionFailure => taskSummary.tasksFailed += 1 case _ => taskSummary.tasksComplete += 1 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index fe6ca1099e6b..2b0816e35747 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -34,9 +34,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { listener.synchronized { val activeStages = listener.activeStages.values.toSeq val pendingStages = listener.pendingStages.values.toSeq - val completedStages = listener.completedStages.reverse.toSeq + val completedStages = listener.completedStages.reverse val numCompletedStages = listener.numCompletedStages - val failedStages = listener.failedStages.reverse.toSeq + val failedStages = listener.failedStages.reverse val numFailedStages = listener.numFailedStages val subPath = "stages" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 52f41298a172..382a6f979f2e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.map(_._2).sum} + {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.values.sum} {v.failedTasks} - {v.reasonToNumKilled.map(_._2).sum} + {v.reasonToNumKilled.values.sum} {v.succeededTasks} {if (stageData.hasInput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 1cf03e1541d1..f78db5ab80d1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -226,7 +226,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { trimJobsIfNecessary(completedJobs) jobData.status = JobExecutionStatus.SUCCEEDED numCompletedJobs += 1 - 
case JobFailed(exception) => + case JobFailed(_) => failedJobs += jobData trimJobsIfNecessary(failedJobs) jobData.status = JobExecutionStatus.FAILED @@ -284,7 +284,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { ) { jobData.numActiveStages -= 1 if (stage.failureReason.isEmpty) { - if (!stage.submissionTime.isEmpty) { + if (stage.submissionTime.isDefined) { jobData.completedStageIndices.add(stage.stageId) } } else { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ff17775008ac..19325a2dc916 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -142,7 +142,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } - val hasAccumulators = externalAccumulables.size > 0 + val hasAccumulators = externalAccumulables.nonEmpty val summary =
      @@ -339,7 +339,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined) val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { + if (validTasks.isEmpty) { None } else { @@ -786,8 +786,8 @@ private[ui] object StagePage { info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = { if (info.finished) { val totalExecutionTime = info.finishTime - info.launchTime - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) + val executorOverhead = metrics.executorDeserializeTime + + metrics.resultSerializationTime math.max( 0, totalExecutionTime - metrics.executorRunTime - executorOverhead - @@ -872,7 +872,7 @@ private[ui] class TaskDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - private var _slicedTaskIds: Set[Long] = null + private var _slicedTaskIds: Set[Long] = _ override def dataSize: Int = data.size diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index f4caad0f5871..256b726fa7ee 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -412,7 +412,7 @@ private[ui] class StageDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) - private var _slicedStageIds: Set[Int] = null + private var _slicedStageIds: Set[Int] = _ override def dataSize: Int = data.size diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 76d7c6d414bc..aa84788f1df8 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -151,7 +151,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { /** Render a stream block */ private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { val replications = block._2 - assert(replications.size > 0) // This must be true because it's the result of "groupBy" + assert(replications.nonEmpty) // This must be true because it's the result of "groupBy" if (replications.size == 1) { streamBlockTableSubrow(block._1, replications.head, replications.size, true) } else { From c734fc504a3f6a3d3b0bd90ff54604b17df2b413 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Mar 2017 13:36:36 -0700 Subject: [PATCH 177/512] [SPARK-20121][SQL] simplify NullPropagation with NullIntolerant ## What changes were proposed in this pull request? Instead of iterating all expressions that can return null for null inputs, we can just check `NullIntolerant`. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #17450 from cloud-fan/null. 
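The idea can be illustrated with a minimal, self-contained sketch; the types below are toy stand-ins, not Spark's Catalyst classes. A single `NullIntolerant` marker lets one generic case fold any marked expression that has a null-literal child, instead of enumerating `BinaryArithmetic`, `Substring`, `GetArrayItem`, and so on case by case, while null-tolerant expressions such as `Coalesce` keep their special handling.
```scala
// Toy illustration only (not Catalyst): fold null-intolerant nodes bottom-up.
sealed trait Expr { def children: Seq[Expr] }
case class Lit(value: Any) extends Expr { def children: Seq[Expr] = Nil }
// Marker: the node evaluates to null whenever any input is null.
trait NullIntolerant extends Expr
case class Add(left: Expr, right: Expr) extends NullIntolerant {
  def children: Seq[Expr] = Seq(left, right)
}
case class Coalesce(children: Seq[Expr]) extends Expr // tolerates null children

object ToyNullPropagation {
  private def isNullLit(e: Expr): Boolean = e == Lit(null)

  def propagate(e: Expr): Expr = {
    val rewritten = e match {
      case Add(l, r)    => Add(propagate(l), propagate(r))
      case Coalesce(cs) => Coalesce(cs.map(propagate).filterNot(isNullLit))
      case other        => other
    }
    rewritten match {
      // One generic case instead of a per-expression list.
      case ni: NullIntolerant if ni.children.exists(isNullLit) => Lit(null)
      case other => other
    }
  }
}

// Add(Lit(1), Add(Lit(null), Lit(2)))  folds to Lit(null);
// Coalesce(Seq(Lit(null), Lit(3)))     keeps only Lit(3).
```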
--- .../sql/catalyst/expressions/arithmetic.scala | 18 +++--- .../expressions/complexTypeExtractors.scala | 8 +-- .../sql/catalyst/expressions/package.scala | 2 +- .../expressions/regexpExpressions.scala | 10 ++-- .../expressions/stringExpressions.scala | 15 ++--- .../sql/catalyst/optimizer/expressions.scala | 59 ++++++------------- 6 files changed, 39 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4870093e9250..f2b252259b89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -113,7 +113,7 @@ case class Abs(child: Expression) protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } -abstract class BinaryArithmetic extends BinaryOperator { +abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType @@ -146,7 +146,7 @@ object BinaryArithmetic { > SELECT 1 _FUNC_ 2; 3 """) -case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { +case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -182,8 +182,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit > SELECT 2 _FUNC_ 1; 1 """) -case class Subtract(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -219,8 +218,7 @@ case class Subtract(left: Expression, right: Expression) > SELECT 2 _FUNC_ 3; 6 """) -case class Multiply(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = NumericType @@ -243,8 +241,7 @@ case class Multiply(left: Expression, right: Expression) 1.0 """) // scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) @@ -324,8 +321,7 @@ case class Divide(left: Expression, right: Expression) > SELECT 2 _FUNC_ 1.8; 0.2 """) -case class Remainder(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = NumericType @@ -412,7 +408,7 @@ case class Remainder(left: Expression, right: Expression) > SELECT _FUNC_(-10, 3); 2 """) -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 0c256c3d890f..de1594d119e1 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -104,7 +104,7 @@ trait ExtractValue extends Expression * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression with ExtractValue { + extends UnaryExpression with ExtractValue with NullIntolerant { lazy val childSchema = child.dataType.asInstanceOf[StructType] @@ -152,7 +152,7 @@ case class GetArrayStructFields( field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression with ExtractValue { + containsNull: Boolean) extends UnaryExpression with ExtractValue with NullIntolerant { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" @@ -213,7 +213,7 @@ case class GetArrayStructFields( * We need to do type checking here as `ordinal` expression maybe unresolved. */ case class GetArrayItem(child: Expression, ordinal: Expression) - extends BinaryExpression with ExpectsInputTypes with ExtractValue { + extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant { // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) * We need to do type checking here as `key` expression maybe unresolved. */ case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue { + extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { private def keyType = child.dataType.asInstanceOf[MapType].keyType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1b00c9e79da2..4c8b177237d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -138,5 +138,5 @@ package object expressions { * input will result in null output). We will use this information during constructing IsNotNull * constraints. 
*/ - trait NullIntolerant + trait NullIntolerant extends Expression } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 4896a6225aa8..b23da537be72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ImplicitCastInputTypes { - self: BinaryExpression => +abstract class StringRegexExpression extends BinaryExpression + with ImplicitCastInputTypes with NullIntolerant { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean @@ -69,8 +69,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { */ @ExpressionDescription( usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.") -case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { +case class Like(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -122,8 +121,7 @@ case class Like(left: Expression, right: Expression) @ExpressionDescription( usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.") -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { +case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 908aa44f81c9..5598a146997c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -297,8 +297,8 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringPredicate extends Predicate with ImplicitCastInputTypes { - self: BinaryExpression => +abstract class StringPredicate extends BinaryExpression + with Predicate with ImplicitCastInputTypes with NullIntolerant { def compare(l: UTF8String, r: UTF8String): Boolean @@ -313,8 +313,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { /** * A function that returns true if the string `left` contains the string `right`. */ -case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class Contains(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") @@ -324,8 +323,7 @@ case class Contains(left: Expression, right: Expression) /** * A function that returns true if the string `left` starts with the string `right`. 
*/ -case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class StartsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") @@ -335,8 +333,7 @@ case class StartsWith(left: Expression, right: Expression) /** * A function that returns true if the string `left` ends with the string `right`. */ -case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class EndsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") @@ -1122,7 +1119,7 @@ case class StringSpace(child: Expression) """) // scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 21d1cd593262..33039127f16c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -347,35 +347,30 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. 
*/ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { - private def nonNullLiteral(e: Expression): Boolean = e match { - case Literal(null, _) => false - case _ => true + private def isNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => true + case _ => false } def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) - case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) - case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) - case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType) - case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) - case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => - Literal.create(null, e.dataType) - case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) - case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. ae.copy(aggregateFunction = Count(Literal(1))) + case IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + + case EqualNullSafe(Literal(null, _), r) => IsNull(r) + case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + // For Coalesce, remove null literals. 
case e @ Coalesce(children) => - val newChildren = children.filter(nonNullLiteral) + val newChildren = children.filterNot(isNullLiteral) if (newChildren.isEmpty) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { @@ -384,33 +379,13 @@ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) - - // Put exceptional cases above if any - case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - - case e: StringPredicate => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - - // If the value expression is NULL then transform the In expression to - // Literal(null) - case In(Literal(null, _), list) => Literal.create(null, BooleanType) + // If the value expression is NULL then transform the In expression to null literal. + case In(Literal(null, _), _) => Literal.create(null, BooleanType) + // Non-leaf NullIntolerant expressions will return null, if at least one of its children is + // a null literal. + case e: NullIntolerant if e.children.exists(isNullLiteral) => + Literal.create(null, e.dataType) } } } From a8a765b3f302c078cb9519c4a17912cd38b9680c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Mar 2017 23:09:33 -0700 Subject: [PATCH 178/512] [SPARK-20151][SQL] Account for partition pruning in scan metadataTime metrics ## What changes were proposed in this pull request? After SPARK-20136, we report metadata timing metrics in scan operator. However, that timing metric doesn't include one of the most important part of metadata, which is partition pruning. This patch adds that time measurement to the scan metrics. ## How was this patch tested? N/A - I tried adding a test in SQLMetricsSuite but it was extremely convoluted to the point that I'm not sure if this is worth it. Author: Reynold Xin Closes #17476 from rxin/SPARK-20151. 
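The pattern behind the change can be sketched independently of Spark's classes; the names below are hypothetical, simplified stand-ins for `FileIndex` and the scan operator. Whichever component performs the partition pruning records its elapsed time, and the scan adds that carried time to its own listing time before reporting a single metadata metric.
```scala
// Hypothetical, simplified mirror of the pattern in this patch (not Spark's API):
// an index remembers how long metadata work such as partition pruning took, and
// the scan folds that into the metadata time it reports.
trait FileListing {
  def listFiles(): Seq[String]
  // Time already spent on metadata operations (e.g. partition pruning), if known.
  def metadataOpsTimeNs: Option[Long] = None
}

final class PrunedListing(files: Seq[String], pruneTimeNs: Long) extends FileListing {
  override def listFiles(): Seq[String] = files
  override val metadataOpsTimeNs: Option[Long] = Some(pruneTimeNs)
}

def reportMetadataTimeMs(listing: FileListing): Long = {
  val carriedNs = listing.metadataOpsTimeNs.getOrElse(0L) // pruning time, if any
  val startNs   = System.nanoTime()
  val files     = listing.listFiles()                     // the scan's own listing
  val totalMs   = ((System.nanoTime() - startNs) + carriedNs) / 1000 / 1000
  println(s"numFiles=${files.size} metadataTime=${totalMs}ms")
  totalMs
}

// Example: reportMetadataTimeMs(new PrunedListing(Seq("part=1/f0.parquet"), pruneTimeNs = 5000000L))
```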
--- .../spark/sql/execution/DataSourceScanExec.scala | 5 +++-- .../sql/execution/datasources/CatalogFileIndex.scala | 7 +++++-- .../spark/sql/execution/datasources/FileIndex.scala | 10 ++++++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 239151495f4b..2fa660c4d5e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -172,12 +172,13 @@ case class FileSourceScanExec( } @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) val startTime = System.nanoTime() val ret = relation.location.listFiles(partitionFilters, dataFilters) - val timeTaken = (System.nanoTime() - startTime) / 1000 / 1000 + val timeTakenMs = ((System.nanoTime() - startTime) + optimizerMetadataTimeNs) / 1000 / 1000 metrics("numFiles").add(ret.map(_.files.size.toLong).sum) - metrics("metadataTime").add(timeTaken) + metrics("metadataTime").add(timeTakenMs) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index db0254f8d558..4046396d0e61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -69,6 +69,7 @@ class CatalogFileIndex( */ def filterPartitions(filters: Seq[Expression]): InMemoryFileIndex = { if (table.partitionColumnNames.nonEmpty) { + val startTime = System.nanoTime() val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => @@ -79,8 +80,9 @@ class CatalogFileIndex( path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) + val timeNs = System.nanoTime() - startTime new PrunedInMemoryFileIndex( - sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec) + sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( sparkSession, rootPaths, table.storage.properties, partitionSchema = None) @@ -111,7 +113,8 @@ private class PrunedInMemoryFileIndex( sparkSession: SparkSession, tableBasePath: Path, fileStatusCache: FileStatusCache, - override val partitionSpec: PartitionSpec) + override val partitionSpec: PartitionSpec, + override val metadataOpsTimeNs: Option[Long]) extends InMemoryFileIndex( sparkSession, partitionSpec.partitions.map(_.path), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 6b99d38fe572..094a66a2820f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -72,4 +72,14 @@ trait FileIndex { /** Schema of the partitioning columns, or the empty schema if the table is not 
partitioned. */ def partitionSchema: StructType + + /** + * Returns an optional metadata operation time, in nanoseconds, for listing files. + * + * We do file listing in query optimization (in order to get the proper statistics) and we want + * to account for file listing time in physical execution (as metrics). To do that, we save the + * file listing time in some implementations and physical execution calls it in this method + * to update the metrics. + */ + def metadataOpsTimeNs: Option[Long] = None } From 254877c2f04414c70d92fa0a00c0ecee1d73aba7 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Fri, 31 Mar 2017 09:17:22 -0700 Subject: [PATCH 179/512] [SPARK-20164][SQL] AnalysisException not tolerant of null query plan. ## What changes were proposed in this pull request? The query plan in an `AnalysisException` may be `null` when an `AnalysisException` object is serialized and then deserialized, since `plan` is marked `transient`. Or when someone throws an `AnalysisException` with a null query plan (which should not happen). `def getMessage` is not tolerant of this and throws a `NullPointerException`, leading to loss of information about the original exception. The fix is to add a `null` check in `getMessage`. ## How was this patch tested? - Unit test Author: Kunal Khamar Closes #17486 from kunalkhamar/spark-20164. --- .../scala/org/apache/spark/sql/AnalysisException.scala | 2 +- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index ff8576157305..50ee6cd4085e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -43,7 +43,7 @@ class AnalysisException protected[sql] ( } override def getMessage: String = { - val planAnnotation = plan.map(p => s";\n$p").getOrElse("") + val planAnnotation = Option(plan).flatten.map(p => s";\n$p").getOrElse("") getSimpleMessage + planAnnotation } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d9e0196c5795..0dd9296a3f0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2598,4 +2598,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(!jobStarted.get(), "Command should not trigger a Spark job.") } + + test("SPARK-20164: AnalysisException should be tolerant to null query plan") { + try { + throw new AnalysisException("", None, None, plan = null) + } catch { + case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) + } + } } From c4c03eed67c05a78dc8944f6119ea708d6b955be Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 31 Mar 2017 09:42:49 -0700 Subject: [PATCH 180/512] [SPARK-20084][CORE] Remove internal.metrics.updatedBlockStatuses from history files. ## What changes were proposed in this pull request? Remove accumulator updates for internal.metrics.updatedBlockStatuses from SparkListenerTaskEnd entries in the history file. These can cause history files to grow to hundreds of GB because the value of the accumulator contains all tracked blocks. ## How was this patch tested? Current History UI tests cover use of the history file. 
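The effect of the change can be shown with a small standalone example (a simplified stand-in for `AccumulableInfo`, not the real class): only the accumulator name decides whether an update is written to the event log, and the oversized `updatedBlockStatuses` entry is the one dropped.
```scala
// Simplified stand-in for Spark's AccumulableInfo, for illustration only.
case class AccumulableInfo(name: Option[String], value: Option[Any])

val blacklist = Set("internal.metrics.updatedBlockStatuses")

val updates = Seq(
  AccumulableInfo(Some("internal.metrics.executorRunTime"), Some(1234L)),
  // This one can carry every tracked block and is what bloats history files.
  AccumulableInfo(Some("internal.metrics.updatedBlockStatuses"),
    Some(Seq("rdd_0_0", "rdd_0_1", "rdd_0_2"))),
  AccumulableInfo(None, None)
)

// Same shape as the patch's filter: keep updates whose name is not blacklisted.
val logged = updates.filterNot(_.name.exists(blacklist.contains))
assert(logged.map(_.name) == Seq(Some("internal.metrics.executorRunTime"), None))
```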
Author: Ryan Blue Closes #17412 from rdblue/SPARK-20084-remove-block-accumulator-info. --- .../org/apache/spark/util/JsonProtocol.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 2cb88919c8c8..1d2cb7acefa3 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -264,8 +264,7 @@ private[spark] object JsonProtocol { ("Submission Time" -> submissionTime) ~ ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ - ("Accumulables" -> JArray( - stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) + ("Accumulables" -> accumulablesToJson(stageInfo.accumulables.values)) } def taskInfoToJson(taskInfo: TaskInfo): JValue = { @@ -281,7 +280,15 @@ private[spark] object JsonProtocol { ("Finish Time" -> taskInfo.finishTime) ~ ("Failed" -> taskInfo.failed) ~ ("Killed" -> taskInfo.killed) ~ - ("Accumulables" -> JArray(taskInfo.accumulables.toList.map(accumulableInfoToJson))) + ("Accumulables" -> accumulablesToJson(taskInfo.accumulables)) + } + + private lazy val accumulableBlacklist = Set("internal.metrics.updatedBlockStatuses") + + def accumulablesToJson(accumulables: Traversable[AccumulableInfo]): JArray = { + JArray(accumulables + .filterNot(_.name.exists(accumulableBlacklist.contains)) + .toList.map(accumulableInfoToJson)) } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { @@ -376,7 +383,7 @@ private[spark] object JsonProtocol { ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) - val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList) + val accumUpdates = accumulablesToJson(exceptionFailure.accumUpdates) ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ From b2349e6a00d569851f0ca91a60e9299306208e92 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 1 Apr 2017 00:56:18 +0800 Subject: [PATCH 181/512] [SPARK-20160][SQL] Move ParquetConversions and OrcConversions Out Of HiveSessionCatalog ### What changes were proposed in this pull request? `ParquetConversions` and `OrcConversions` should be treated as regular `Analyzer` rules. It is not reasonable to be part of `HiveSessionCatalog`. This PR also combines two rules `ParquetConversions` and `OrcConversions` to build a new rule `RelationConversions `. After moving these two rules out of HiveSessionCatalog, the next step is to clean up, rename and move `HiveMetastoreCatalog` because it is not related to the hive package any more. ### How was this patch tested? The existing test cases Author: Xiao Li Closes #17484 from gatorsmile/cleanup. 
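The refactoring below changes where the conversion rules live, not how they are switched on. As a hedged usage sketch (the table name is made up, and defaults may vary by version), the existing configuration keys still govern whether Hive-serde Parquet/ORC tables are read through the data source path:
```scala
import org.apache.spark.sql.SparkSession

// Sketch only: some_hive_parquet_table is hypothetical.
val spark = SparkSession.builder()
  .appName("relation-conversions")
  .enableHiveSupport()
  .getOrCreate()

// Convert Hive-serde Parquet tables to the native Parquet data source path.
spark.conf.set("spark.sql.hive.convertMetastoreParquet", "true")
// Leave ORC tables on the Hive SerDe path instead of converting them.
spark.conf.set("spark.sql.hive.convertMetastoreOrc", "false")

spark.sql("SELECT * FROM some_hive_parquet_table").explain()
```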
--- .../spark/sql/hive/HiveMetastoreCatalog.scala | 96 ++----------------- .../spark/sql/hive/HiveSessionCatalog.scala | 25 +---- .../sql/hive/HiveSessionStateBuilder.scala | 3 +- .../spark/sql/hive/HiveStrategies.scala | 56 ++++++++++- .../hive/JavaMetastoreDataSourcesSuite.java | 3 +- .../spark/sql/hive/StatisticsSuite.scala | 4 +- .../apache/spark/sql/hive/parquetSuites.scala | 5 +- 7 files changed, 70 insertions(+), 122 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 305bd007c93f..10f432570e94 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,11 +28,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} -import org.apache.spark.sql.hive.orc.OrcFileFormat import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ @@ -48,14 +44,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache import HiveMetastoreCatalog._ - private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase - - def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { - QualifiedTableName( - tableIdent.database.getOrElse(getCurrentDatabase).toLowerCase, - tableIdent.table.toLowerCase) - } - /** These locks guard against multiple attempts to instantiate a table, which wastes memory. 
*/ private val tableCreationLocks = Striped.lazyWeakLock(100) @@ -68,11 +56,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { - // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) - val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - val dbLocation = sparkSession.sharedState.externalCatalog.getDatabase(dbName).locationUri - new Path(new Path(dbLocation), tblName).toString + // For testing only + private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { + val key = QualifiedTableName( + table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, + table.table.toLowerCase) + tableRelationCache.getIfPresent(key) } private def getCached( @@ -122,7 +111,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - private def convertToLogicalRelation( + def convertToLogicalRelation( relation: CatalogRelation, options: Map[String, String], fileFormatClass: Class[_ <: FileFormat], @@ -273,78 +262,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log case NonFatal(ex) => logWarning(s"Unable to save case-sensitive schema for table ${identifier.unquotedString}", ex) } - - /** - * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet - * data source relations for better performance. - */ - object ParquetConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreParquet(relation: CatalogRelation): Boolean = { - relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && - sessionState.conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) - } - - private def convertToParquetRelation(relation: CatalogRelation): LogicalRelation = { - val fileFormatClass = classOf[ParquetFileFormat] - val mergeSchema = sessionState.conf.getConf( - HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) - val options = Map(ParquetOptions.MERGE_SCHEMA -> mergeSchema.toString) - - convertToLogicalRelation(relation, options, fileFormatClass, "parquet") - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { - // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) - // Inserting into partitioned table is not supported in Parquet data source (yet). - if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, query, overwrite, ifNotExists) - - // Read path - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && - shouldConvertMetastoreParquet(relation) => - convertToParquetRelation(relation) - } - } - } - - /** - * When scanning Metastore ORC tables, convert them to ORC data source relations - * for better performance. 
- */ - object OrcConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreOrc(relation: CatalogRelation): Boolean = { - relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && - sessionState.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) - } - - private def convertToOrcRelation(relation: CatalogRelation): LogicalRelation = { - val fileFormatClass = classOf[OrcFileFormat] - val options = Map[String, String]() - - convertToLogicalRelation(relation, options, fileFormatClass, "orc") - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { - // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) - // Inserting into partitioned table is not supported in Orc data source (yet). - if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, query, overwrite, ifNotExists) - - // Read path - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && - shouldConvertMetastoreOrc(relation) => - convertToOrcRelation(relation) - } - } - } } + private[hive] object HiveMetastoreCatalog { def mergeWithMetastoreSchema( metastoreSchema: StructType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 2cc20a791d80..9e3eb2dd8234 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -26,14 +26,12 @@ import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} @@ -43,7 +41,7 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, globalTempViewManager: GlobalTempViewManager, - private val metastoreCatalog: HiveMetastoreCatalog, + val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, @@ -58,25 +56,6 @@ private[sql] class HiveSessionCatalog( parser, functionResourceLoader) { - // ---------------------------------------------------------------- - // | Methods and fields for interacting with HiveMetastoreCatalog | - // ---------------------------------------------------------------- - - // These 2 rules must be run before all other DDL post-hoc resolution rules, i.e. - // `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`. 
- val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions - val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions - - def hiveDefaultTableFilePath(name: TableIdentifier): String = { - metastoreCatalog.hiveDefaultTableFilePath(name) - } - - // For testing only - private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { - val key = metastoreCatalog.getQualifiedTableName(table) - tableRelationCache.getIfPresent(key) - } - override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { makeFunctionBuilder(funcName, Utils.classForName(className)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2f3dfa05e9ef..9d3b31f39c0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -75,8 +75,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = new DetermineTableStats(session) +: - catalog.ParquetConversions +: - catalog.OrcConversions +: + RelationConversions(conf, catalog) +: PreprocessTableCreation(session) +: PreprocessTableInsertion(conf) +: DataSourceAnalysis(conf) +: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index b5ce027d51e7..0465e9c031e2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive import java.io.IOException -import java.net.URI import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst @@ -31,9 +30,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.CreateTable +import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.hive.orc.OrcFileFormat +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} /** @@ -170,6 +171,55 @@ object HiveAnalysis extends Rule[LogicalPlan] { } } +/** + * Relation conversion from metastore relations to data source relations for better performance + * + * - When writing to non-partitioned Hive-serde Parquet/Orc tables + * - When scanning Hive-serde Parquet/ORC tables + * + * This rule must be run before all other DDL post-hoc resolution rules, i.e. + * `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`. 
+ */ +case class RelationConversions( + conf: SQLConf, + sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { + private def isConvertible(relation: CatalogRelation): Boolean = { + (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && + conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET)) || + (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && + conf.getConf(HiveUtils.CONVERT_METASTORE_ORC)) + } + + private def convert(relation: CatalogRelation): LogicalRelation = { + if (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet")) { + val options = Map(ParquetOptions.MERGE_SCHEMA -> + conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) + sessionCatalog.metastoreCatalog + .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") + } else { + val options = Map[String, String]() + sessionCatalog.metastoreCatalog + .convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc") + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan transformUp { + // Write path + case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) + // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && isConvertible(r) => + InsertIntoTable(convert(r), partition, query, overwrite, ifNotExists) + + // Read path + case relation: CatalogRelation + if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => + convert(relation) + } + } +} + private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. self: SparkPlanner => diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 0b157a45e6e0..25bd4d0017bd 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -72,8 +72,7 @@ public void setUp() throws IOException { path.delete(); } HiveSessionCatalog catalog = (HiveSessionCatalog) sqlContext.sessionState().catalog(); - hiveManagedPath = new Path( - catalog.hiveDefaultTableFilePath(new TableIdentifier("javaSavedTable"))); + hiveManagedPath = new Path(catalog.defaultTablePath(new TableIdentifier("javaSavedTable"))); fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); fs.delete(hiveManagedPath, true); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 962998ea6fb6..3191b9975fbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -413,7 +413,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } // Table lookup will make the table cached. 
spark.table(tableIndent) - statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent) + statsBeforeUpdate = catalog.metastoreCatalog.getCachedDataSourceTable(tableIndent) .asInstanceOf[LogicalRelation].catalogTable.get.stats.get sql(s"INSERT INTO $tableName SELECT 2") @@ -423,7 +423,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") } spark.table(tableIndent) - statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent) + statsAfterUpdate = catalog.metastoreCatalog.getCachedDataSourceTable(tableIndent) .asInstanceOf[LogicalRelation].catalogTable.get.stats.get } (statsBeforeUpdate, statsAfterUpdate) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 9fc2923bb6fd..23f21e6b9931 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -449,8 +449,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } - private def getCachedDataSourceTable(id: TableIdentifier): LogicalPlan = { - sessionState.catalog.asInstanceOf[HiveSessionCatalog].getCachedDataSourceTable(id) + private def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { + sessionState.catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog + .getCachedDataSourceTable(table) } test("Caching converted data source Parquet Relations") { From 567a50acfb0ae26bd430c290348886d494963696 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 31 Mar 2017 10:58:43 -0700 Subject: [PATCH 182/512] [SPARK-20165][SS] Resolve state encoder's deserializer in driver in FlatMapGroupsWithStateExec ## What changes were proposed in this pull request? - Encoder's deserializer must be resolved at the driver where the class is defined. Otherwise there are corner cases using nested classes where resolving at the executor can fail. - Fixed flaky test related to processing time timeout. The flakiness is caused because the test thread (that adds data to memory source) has a race condition with the streaming query thread. When testing the manual clock, the goal is to add data and increment clock together atomically, such that a trigger sees new data AND updated clock simultaneously (both or none). This fix adds additional synchronization in when adding data; it makes sure that the streaming query thread is waiting on the manual clock to be incremented (so no batch is currently running) before adding data. - Added`testQuietly` on some tests that generate a lot of error logs. ## How was this patch tested? Multiple runs on existing unit tests Author: Tathagata Das Closes #17488 from tdas/SPARK-20165. 
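The encoder part of this fix follows a pattern that can be sketched in a few lines; the class below is illustrative only, not the actual operator, although `ExpressionEncoder.resolveAndBind()` is the call used in the diff. Resolution that depends on user classes happens eagerly in the constructor, which runs on the driver, so the code shipped to executors only uses the already-resolved expression.
```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Illustrative only: resolve the state encoder where the operator is built (driver),
// not inside the task body (executor), where nested classes may fail to resolve.
class DriverResolvedState[S](stateEncoder: ExpressionEncoder[S]) extends Serializable {
  // Evaluated at construction time, on the driver.
  private val stateDeserializer = stateEncoder.resolveAndBind().deserializer

  // Anything sent to executors captures only the pre-resolved expression.
  def describe(): String =
    s"resolved deserializer: ${stateDeserializer.getClass.getSimpleName}"
}
```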
--- .../FlatMapGroupsWithStateExec.scala | 28 +++++++++++-------- .../sql/streaming/FileStreamSourceSuite.scala | 4 +-- .../FlatMapGroupsWithStateSuite.scala | 7 +++-- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 23 +++++++++++++-- .../sql/streaming/StreamingQuerySuite.scala | 2 +- 6 files changed, 45 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index c7262ea97200..e42df5dd61c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -68,6 +68,20 @@ case class FlatMapGroupsWithStateExec( val encSchemaAttribs = stateEncoder.schema.toAttributes if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs } + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val encoderSerializer = stateEncoder.namedExpressions + if (isTimeoutEnabled) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. + private val stateDeserializer = stateEncoder.resolveAndBind().deserializer + /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = @@ -139,19 +153,9 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converter for translating state rows to Java objects + // Converters for translating state between rows and Java objects private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateEncoder.resolveAndBind().deserializer, stateAttributes) - - // Converter for translating state Java objects to rows - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } + stateDeserializer, stateAttributes) private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) // Index of the additional metadata fields in the state row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index f705da3d6a70..171877abe6e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -909,7 +909,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } - test("max files per trigger - incorrect values") { + testQuietly("max files per trigger - incorrect values") { val testTable = "maxFilesPerTrigger_test" withTable(testTable) { withTempDir { case src => @@ -1326,7 +1326,7 @@ class FileStreamSourceStressTestSuite extends FileStreamSourceTest { import testImplicits._ - test("file source stress test") { + testQuietly("file source stress test") { 
val src = Utils.createTempDir(namePrefix = "streaming.src") val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index a00a1a582a97..c8e31e3ca2e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,6 +21,8 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -574,11 +576,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assertNumStateRows(total = 1, updated = 2), StopStream, - StartStream(ProcessingTime("1 second"), triggerClock = clock), - AdvanceManualClock(10 * 1000), + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "c"), - AdvanceManualClock(1 * 1000), + AdvanceManualClock(11 * 1000), CheckLastBatch(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 32920f6dfa22..388f15405e70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -426,7 +426,7 @@ class StreamSuite extends StreamTest { CheckAnswer((1, 2), (2, 2), (3, 2))) } - test("recover from a Spark v2.1 checkpoint") { + testQuietly("recover from a Spark v2.1 checkpoint") { var inputData: MemoryStream[Int] = null var query: DataStreamWriter[Row] = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 8cf179133681..951ff2ca0d68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -488,8 +488,27 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { case a: AddData => try { - // Add data and get the source where it was added, and the expected offset of the - // added data. + + // If the query is running with manual clock, then wait for the stream execution + // thread to start waiting for the clock to increment. This is needed so that we + // are adding data when there is no trigger that is active. This would ensure that + // the data gets deterministically added to the next batch triggered after the manual + // clock is incremented in following AdvanceManualClock. This avoid race conditions + // between the test thread and the stream execution thread in tests using manual + // clock. 
+ if (currentStream != null && + currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + eventually("Error while synchronizing with manual clock before adding data") { + if (currentStream.isActive) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (!currentStream.isActive) { + failTest("Query terminated while synchronizing with manual clock") + } + } + // Add data val queryToUse = Option(currentStream).orElse(Option(lastStream)) val (source, offset) = a.addData(queryToUse) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 3f41ecdb7ff6..1172531fe998 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -487,7 +487,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("StreamingQuery should be Serializable but cannot be used in executors") { + testQuietly("StreamingQuery should be Serializable but cannot be used in executors") { def startQuery(ds: Dataset[Int], queryName: String): StreamingQuery = { ds.writeStream .queryName(queryName) From cf5963c961e7eba37bdd58658ed4dfff66ce3c72 Mon Sep 17 00:00:00 2001 From: 郭小龙 10207633 Date: Sat, 1 Apr 2017 11:48:58 +0100 Subject: [PATCH 183/512] [SPARK-20177] Document about compression way has some little detail changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Make a few small clarifications in the compression-related documentation: 1. spark.eventLog.compress: add 'Compression will use spark.io.compression.codec.' 2. spark.broadcast.compress: add 'Compression will use spark.io.compression.codec.' 3. spark.rdd.compress: add 'Compression will use spark.io.compression.codec.' 4. spark.io.compression.codec: mention that it also applies to the event log. For example, from the current documentation it is not clear which codec is used to compress the event log. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Closes #17498 from guoxiaolongzte/SPARK-20177. --- docs/configuration.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index a9753925407d..2687f542b8bd 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -639,6 +639,7 @@ Apart from these, the following properties are also available, and may be useful false Whether to compress logged events, if spark.eventLog.enabled is true. + Compression will use spark.io.compression.codec. @@ -773,14 +774,15 @@ Apart from these, the following properties are also available, and may be useful true Whether to compress broadcast variables before sending them. Generally a good idea. + Compression will use spark.io.compression.codec. spark.io.compression.codec lz4 - The codec used to compress internal data such as RDD partitions, event log, broadcast variables - and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, and snappy.
You can also use fully qualified class names to specify the codec, e.g. org.apache.spark.io.LZ4CompressionCodec, @@ -881,6 +883,7 @@ Apart from these, the following properties are also available, and may be useful StorageLevel.MEMORY_ONLY_SER in Java and Scala or StorageLevel.MEMORY_ONLY in Python). Can save substantial space at the cost of some extra CPU time. + Compression will use spark.io.compression.codec. From 89d6822f722912d2b05571a95a539092091650b5 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 1 Apr 2017 20:43:13 +0800 Subject: [PATCH 184/512] [SPARK-19148][SQL][FOLLOW-UP] do not expose the external table concept in Catalog ### What changes were proposed in this pull request? After we renames `Catalog`.`createExternalTable` to `createTable` in the PR: https://github.com/apache/spark/pull/16528, we also need to deprecate the corresponding functions in `SQLContext`. ### How was this patch tested? N/A Author: Xiao Li Closes #17502 from gatorsmile/deprecateCreateExternalTable. --- .../org/apache/spark/sql/SQLContext.scala | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 234ef2dffc6b..cc2983987eb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.beans.BeanInfo import java.util.Properties import scala.collection.immutable @@ -527,8 +526,9 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable(tableName: String, path: String): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, path) + sparkSession.catalog.createTable(tableName, path) } /** @@ -538,11 +538,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, path: String, source: String): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, path, source) + sparkSession.catalog.createTable(tableName, path, source) } /** @@ -552,11 +553,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, options: java.util.Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, options) + sparkSession.catalog.createTable(tableName, source, options) } /** @@ -567,11 +569,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, options: Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, options) + sparkSession.catalog.createTable(tableName, source, options) } /** @@ -581,12 +584,13 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, schema: StructType, options: java.util.Map[String, 
String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, schema, options) + sparkSession.catalog.createTable(tableName, source, schema, options) } /** @@ -597,12 +601,13 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, schema: StructType, options: Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, schema, options) + sparkSession.catalog.createTable(tableName, source, schema, options) } /** @@ -1089,9 +1094,9 @@ object SQLContext { * method for internal use. */ private[sql] def beansToRows( - data: Iterator[_], - beanClass: Class[_], - attrs: Seq[AttributeReference]): Iterator[InternalRow] = { + data: Iterator[_], + beanClass: Class[_], + attrs: Seq[AttributeReference]): Iterator[InternalRow] = { val extractors = JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod) val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => From 2287f3d0b85730995bedc489a017de5700d6e1e4 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Sat, 1 Apr 2017 22:19:08 +0800 Subject: [PATCH 185/512] [SPARK-20186][SQL] BroadcastHint should use child's stats ## What changes were proposed in this pull request? `BroadcastHint` should use child's statistics and set `isBroadcastable` to true. ## How was this patch tested? Added a new stats estimation test for `BroadcastHint`. Author: wangzhenhua Closes #17504 from wzhfy/broadcastHintEstimation. --- .../plans/logical/basicLogicalOperators.scala | 2 +- .../BasicStatsEstimationSuite.scala | 21 ++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 5cbf263d1ce4..19db42c80895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -383,7 +383,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { // set isBroadcastable to true so the child will be broadcasted override def computeStats(conf: CatalystConf): Statistics = - super.computeStats(conf).copy(isBroadcastable = true) + child.stats(conf).copy(isBroadcastable = true) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index e5dc811c8b7d..0d92c1e35565 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -35,6 +35,23 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { // row count * (overhead + column size) size = Some(10 * (8 + 4))) + test("BroadcastHint estimation") { + val filter = Filter(Literal(true), plan) + val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false, + rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat))) + val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable 
= false) + checkStats( + filter, + expectedStatsCboOn = filterStatsCboOn, + expectedStatsCboOff = filterStatsCboOff) + + val broadcastHint = BroadcastHint(filter) + checkStats( + broadcastHint, + expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true), + expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true)) + } + test("limit estimation: limit < child's rowCount") { val localLimit = LocalLimit(Literal(2), plan) val globalLimit = GlobalLimit(Literal(2), plan) @@ -97,8 +114,10 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { plan: LogicalPlan, expectedStatsCboOn: Statistics, expectedStatsCboOff: Statistics): Unit = { - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) // Invalidate statistics + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + plan.invalidateStatsCache() assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) } From d40cbb861898de881621d5053a468af570d72127 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 2 Apr 2017 07:26:49 -0700 Subject: [PATCH 186/512] [SPARK-20143][SQL] DataType.fromJson should throw an exception with better message ## What changes were proposed in this pull request? Currently, `DataType.fromJson` throws `scala.MatchError` or `java.util.NoSuchElementException` in some cases when the JSON input is invalid as below: ```scala DataType.fromJson(""""abcd"""") ``` ``` java.util.NoSuchElementException: key not found: abcd at ... ``` ```scala DataType.fromJson("""{"abcd":"a"}""") ``` ``` scala.MatchError: JObject(List((abcd,JString(a)))) (of class org.json4s.JsonAST$JObject) at ... ``` ```scala DataType.fromJson("""{"fields": [{"a":123}], "type": "struct"}""") ``` ``` scala.MatchError: JObject(List((a,JInt(123)))) (of class org.json4s.JsonAST$JObject) at ... ``` After this PR, ```scala DataType.fromJson(""""abcd"""") ``` ``` java.lang.IllegalArgumentException: Failed to convert the JSON string 'abcd' to a data type. at ... ``` ```scala DataType.fromJson("""{"abcd":"a"}""") ``` ``` java.lang.IllegalArgumentException: Failed to convert the JSON string '{"abcd":"a"}' to a data type. at ... ``` ```scala DataType.fromJson("""{"fields": [{"a":123}], "type": "struct"}""") at ... ``` ``` java.lang.IllegalArgumentException: Failed to convert the JSON string '{"a":123}' to a field. ``` ## How was this patch tested? Unit test added in `DataTypeSuite`. Author: hyukjinkwon Closes #17468 from HyukjinKwon/fromjson_exception. 
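
For callers that persist schemas as JSON strings (for example in table or file metadata), the failure mode is now a consistent `IllegalArgumentException` for structurally invalid input and Jackson's `JsonParseException` for unparseable input. The sketch below is illustrative only and not part of this patch; the helper name and the fall-back-to-default behaviour are assumptions.

```scala
import scala.util.{Failure, Success, Try}

import com.fasterxml.jackson.core.JsonParseException

import org.apache.spark.sql.types.{DataType, StructType}

// Hypothetical guard: return a caller-supplied default schema when the stored
// JSON cannot be converted back into a DataType.
def schemaFromJsonOrElse(json: String, default: StructType): DataType = {
  Try(DataType.fromJson(json)) match {
    case Success(dt) => dt
    case Failure(_: IllegalArgumentException | _: JsonParseException) => default
    case Failure(other) => throw other
  }
}
```
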
--- .../org/apache/spark/sql/types/DataType.scala | 12 +++++++- .../spark/sql/types/DataTypeSuite.scala | 28 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 2642d9395ba8..26871259c6b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -115,7 +115,10 @@ object DataType { name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType(other) + case other => nonDecimalNameToType.getOrElse( + other, + throw new IllegalArgumentException( + s"Failed to convert the JSON string '$name' to a data type.")) } } @@ -164,6 +167,10 @@ object DataType { ("sqlType", v: JValue), ("type", JString("udt"))) => new PythonUserDefinedType(parseDataType(v), pyClass, serialized) + + case other => + throw new IllegalArgumentException( + s"Failed to convert the JSON string '${compact(render(other))}' to a data type.") } private def parseStructField(json: JValue): StructField = json match { @@ -179,6 +186,9 @@ object DataType { ("nullable", JBool(nullable)), ("type", dataType: JValue)) => StructField(name, parseDataType(dataType), nullable) + case other => + throw new IllegalArgumentException( + s"Failed to convert the JSON string '${compact(render(other))}' to a field.") } protected[types] def buildFormattedString( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 05cb999af6a5..f078ef013387 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import com.fasterxml.jackson.core.JsonParseException + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -246,6 +248,32 @@ class DataTypeSuite extends SparkFunSuite { checkDataTypeFromJson(structType) checkDataTypeFromDDL(structType) + test("fromJson throws an exception when given type string is invalid") { + var message = intercept[IllegalArgumentException] { + DataType.fromJson(""""abcd"""") + }.getMessage + assert(message.contains( + "Failed to convert the JSON string 'abcd' to a data type.")) + + message = intercept[IllegalArgumentException] { + DataType.fromJson("""{"abcd":"a"}""") + }.getMessage + assert(message.contains( + """Failed to convert the JSON string '{"abcd":"a"}' to a data type""")) + + message = intercept[IllegalArgumentException] { + DataType.fromJson("""{"fields": [{"a":123}], "type": "struct"}""") + }.getMessage + assert(message.contains( + """Failed to convert the JSON string '{"a":123}' to a field.""")) + + // Malformed JSON string + message = intercept[JsonParseException] { + DataType.fromJson("abcd") + }.getMessage + assert(message.contains("Unrecognized token 'abcd'")) + } + def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { test(s"Check the default size of $dataType") { assert(dataType.defaultSize === expectedDefaultSize) From 76de2d115364aa6a1fdaacdfae05f0c695c953b8 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sun, 2 Apr 2017 15:31:13 +0100 Subject: [PATCH 187/512] 
=?UTF-8?q?[SPARK-20123][BUILD]=20SPARK=5FHOME=20v?= =?UTF-8?q?ariable=20might=20have=20spaces=20in=20it(e.g.=20$SPARK?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA Issue: https://issues.apache.org/jira/browse/SPARK-20123 ## What changes were proposed in this pull request? If $SPARK_HOME or $FWDIR variable contains spaces, then use "./dev/make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.7 -Phive -Phive-thriftserver -Pmesos -Pyarn" build spark will failed. ## How was this patch tested? manual tests Author: zuotingbing Closes #17452 from zuotingbing/spark-bulid. --- R/check-cran.sh | 20 ++++++++++---------- R/create-docs.sh | 10 +++++----- R/create-rd.sh | 8 ++++---- R/install-dev.sh | 14 +++++++------- R/install-source-package.sh | 20 ++++++++++---------- dev/make-distribution.sh | 32 ++++++++++++++++---------------- 6 files changed, 52 insertions(+), 52 deletions(-) diff --git a/R/check-cran.sh b/R/check-cran.sh index a188b1448a67..22cc9c6b601f 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -20,18 +20,18 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null -. $FWDIR/find-r.sh +. "$FWDIR/find-r.sh" # Install the package (this is required for code in vignettes to run when building it later) # Build the latest docs, but not vignettes, which is built with the package next -. $FWDIR/install-dev.sh +. "$FWDIR/install-dev.sh" # Build source package with vignettes SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" -. "${SPARK_HOME}"/bin/load-spark-env.sh +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ -f "${SPARK_HOME}/RELEASE" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else @@ -40,16 +40,16 @@ fi if [ -d "$SPARK_JARS_DIR" ]; then # Build a zip file containing the source package with vignettes - SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/R" CMD build "$FWDIR/pkg" find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete else - echo "Error Spark JARs not found in $SPARK_HOME" + echo "Error Spark JARs not found in '$SPARK_HOME'" exit 1 fi # Run check as-cran. -VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` +VERSION=`grep Version "$FWDIR/pkg/DESCRIPTION" | awk '{print $NF}'` CRAN_CHECK_OPTIONS="--as-cran" @@ -67,10 +67,10 @@ echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] then - "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz + "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" else # This will run tests and/or build vignettes, and require SPARK_HOME - SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 6bef7e75e3bd..310dbc5fb50a 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -33,15 +33,15 @@ export FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" export SPARK_HOME="$(cd "`dirname "${BASH_SOURCE[0]}"`"/..; pwd)" # Required for setting SPARK_SCALA_VERSION -. "${SPARK_HOME}"/bin/load-spark-env.sh +. "${SPARK_HOME}/bin/load-spark-env.sh" echo "Using Scala $SPARK_SCALA_VERSION" -pushd $FWDIR > /dev/null -. 
$FWDIR/find-r.sh +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" # Install the package (this will also generate the Rd files) -. $FWDIR/install-dev.sh +. "$FWDIR/install-dev.sh" # Now create HTML files @@ -49,7 +49,7 @@ pushd $FWDIR > /dev/null mkdir -p pkg/html pushd pkg/html -"$R_SCRIPT_PATH/"Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' +"$R_SCRIPT_PATH/Rscript" -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' popd diff --git a/R/create-rd.sh b/R/create-rd.sh index d17e1617397d..ff622a41a46c 100755 --- a/R/create-rd.sh +++ b/R/create-rd.sh @@ -29,9 +29,9 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" # Generate Rd files if devtools is installed -"$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' +"$R_SCRIPT_PATH/Rscript" -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' diff --git a/R/install-dev.sh b/R/install-dev.sh index 45e641170581..d61355271830 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -29,21 +29,21 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" LIB_DIR="$FWDIR/lib" -mkdir -p $LIB_DIR +mkdir -p "$LIB_DIR" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" -. $FWDIR/create-rd.sh +. "$FWDIR/create-rd.sh" # Install SparkR to $LIB_DIR -"$R_SCRIPT_PATH/"R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ +"$R_SCRIPT_PATH/R" CMD INSTALL --library="$LIB_DIR" "$FWDIR/pkg/" # Zip the SparkR package so that it can be distributed to worker nodes on YARN -cd $LIB_DIR +cd "$LIB_DIR" jar cfM "$LIB_DIR/sparkr.zip" SparkR popd > /dev/null diff --git a/R/install-source-package.sh b/R/install-source-package.sh index c6e443c04e62..8de3569d1d48 100755 --- a/R/install-source-package.sh +++ b/R/install-source-package.sh @@ -29,28 +29,28 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" if [ -z "$VERSION" ]; then - VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` + VERSION=`grep Version "$FWDIR/pkg/DESCRIPTION" | awk '{print $NF}'` fi -if [ ! -f "$FWDIR"/SparkR_"$VERSION".tar.gz ]; then - echo -e "R source package file $FWDIR/SparkR_$VERSION.tar.gz is not found." +if [ ! -f "$FWDIR/SparkR_$VERSION.tar.gz" ]; then + echo -e "R source package file '$FWDIR/SparkR_$VERSION.tar.gz' is not found." 
echo -e "Please build R source package with check-cran.sh" exit -1; fi echo "Removing lib path and installing from source package" LIB_DIR="$FWDIR/lib" -rm -rf $LIB_DIR -mkdir -p $LIB_DIR -"$R_SCRIPT_PATH/"R CMD INSTALL SparkR_"$VERSION".tar.gz --library=$LIB_DIR +rm -rf "$LIB_DIR" +mkdir -p "$LIB_DIR" +"$R_SCRIPT_PATH/R" CMD INSTALL "SparkR_$VERSION.tar.gz" --library="$LIB_DIR" # Zip the SparkR package so that it can be distributed to worker nodes on YARN -pushd $LIB_DIR > /dev/null +pushd "$LIB_DIR" > /dev/null jar cfM "$LIB_DIR/sparkr.zip" SparkR popd > /dev/null diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 769cbda4fe34..48a824499acb 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -140,7 +140,7 @@ echo "Spark version is $VERSION" if [ "$MAKE_TGZ" == "true" ]; then echo "Making spark-$VERSION-bin-$NAME.tgz" else - echo "Making distribution for Spark $VERSION in $DISTDIR..." + echo "Making distribution for Spark $VERSION in '$DISTDIR'..." fi # Build uber fat JAR @@ -170,7 +170,7 @@ cp "$SPARK_HOME"/assembly/target/scala*/jars/* "$DISTDIR/jars/" # Only create the yarn directory if the yarn artifacts were build. if [ -f "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar ]; then - mkdir "$DISTDIR"/yarn + mkdir "$DISTDIR/yarn" cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/yarn" fi @@ -179,7 +179,7 @@ mkdir -p "$DISTDIR/examples/jars" cp "$SPARK_HOME"/examples/target/scala*/jars/* "$DISTDIR/examples/jars" # Deduplicate jars that have already been packaged as part of the main Spark dependencies. -for f in "$DISTDIR/examples/jars/"*; do +for f in "$DISTDIR"/examples/jars/*; do name=$(basename "$f") if [ -f "$DISTDIR/jars/$name" ]; then rm "$DISTDIR/examples/jars/$name" @@ -188,14 +188,14 @@ done # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r "$SPARK_HOME"/examples/src/main "$DISTDIR/examples/src/" +cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files cp "$SPARK_HOME/LICENSE" "$DISTDIR" cp -r "$SPARK_HOME/licenses" "$DISTDIR" cp "$SPARK_HOME/NOTICE" "$DISTDIR" -if [ -e "$SPARK_HOME"/CHANGES.txt ]; then +if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" fi @@ -217,43 +217,43 @@ fi # Make R package - this is used for both CRAN release and packing R layout into distribution if [ "$MAKE_R" == "true" ]; then echo "Building R source package" - R_PACKAGE_VERSION=`grep Version $SPARK_HOME/R/pkg/DESCRIPTION | awk '{print $NF}'` + R_PACKAGE_VERSION=`grep Version "$SPARK_HOME/R/pkg/DESCRIPTION" | awk '{print $NF}'` pushd "$SPARK_HOME/R" > /dev/null # Build source package and run full checks # Do not source the check-cran.sh - it should be run from where it is for it to set SPARK_HOME - NO_TESTS=1 "$SPARK_HOME/"R/check-cran.sh + NO_TESTS=1 "$SPARK_HOME/R/check-cran.sh" # Move R source package to match the Spark release version if the versions are not the same. # NOTE(shivaram): `mv` throws an error on Linux if source and destination are same file if [ "$R_PACKAGE_VERSION" != "$VERSION" ]; then - mv $SPARK_HOME/R/SparkR_"$R_PACKAGE_VERSION".tar.gz $SPARK_HOME/R/SparkR_"$VERSION".tar.gz + mv "$SPARK_HOME/R/SparkR_$R_PACKAGE_VERSION.tar.gz" "$SPARK_HOME/R/SparkR_$VERSION.tar.gz" fi # Install source package to get it to generate vignettes rds files, etc. 
- VERSION=$VERSION "$SPARK_HOME/"R/install-source-package.sh + VERSION=$VERSION "$SPARK_HOME/R/install-source-package.sh" popd > /dev/null else echo "Skipping building R source package" fi # Copy other things -mkdir "$DISTDIR"/conf -cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf +mkdir "$DISTDIR/conf" +cp "$SPARK_HOME"/conf/*.template "$DISTDIR/conf" cp "$SPARK_HOME/README.md" "$DISTDIR" cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" # Remove the python distribution from dist/ if we built it if [ "$MAKE_PIP" == "true" ]; then - rm -f $DISTDIR/python/dist/pyspark-*.tar.gz + rm -f "$DISTDIR"/python/dist/pyspark-*.tar.gz fi cp -r "$SPARK_HOME/sbin" "$DISTDIR" # Copy SparkR if it exists -if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then - mkdir -p "$DISTDIR"/R/lib - cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib - cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib +if [ -d "$SPARK_HOME/R/lib/SparkR" ]; then + mkdir -p "$DISTDIR/R/lib" + cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR/R/lib" + cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR/R/lib" fi if [ "$MAKE_TGZ" == "true" ]; then From 657cb9541db8508ce64d08cc3de14cd02adf16b5 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sun, 2 Apr 2017 15:39:51 +0100 Subject: [PATCH 188/512] [SPARK-20173][SQL][HIVE-THRIFTSERVER] Throw NullPointerException when HiveThriftServer2 is shutdown ## What changes were proposed in this pull request? If the shutdown hook called before the variable `uiTab` is set , it will throw a NullPointerException. ## How was this patch tested? manual tests Author: zuotingbing Closes #17496 from zuotingbing/SPARK-HiveThriftServer2. --- .../apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 13c6f11f461c..14553601b1d5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.{ShutdownHookManager, Utils} */ object HiveThriftServer2 extends Logging { var LOG = LogFactory.getLog(classOf[HiveServer2]) - var uiTab: Option[ThriftServerTab] = _ + var uiTab: Option[ThriftServerTab] = None var listener: HiveThriftServer2Listener = _ /** From 93dbfe705f3e7410a7267e406332ffb3c3077829 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 2 Apr 2017 11:59:27 -0700 Subject: [PATCH 189/512] [SPARK-20159][SPARKR][SQL] Support all catalog API in R ## What changes were proposed in this pull request? Add a set of catalog API in R ``` "currentDatabase", "listColumns", "listDatabases", "listFunctions", "listTables", "recoverPartitions", "refreshByPath", "refreshTable", "setCurrentDatabase", ``` https://github.com/apache/spark/pull/17483/files#diff-6929e6c5e59017ff954e110df20ed7ff ## How was this patch tested? manual tests, unit tests Author: Felix Cheung Closes #17483 from felixcheung/rcatalog. 
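
The new R functions are thin wrappers that reach the JVM-side `org.apache.spark.sql.catalog.Catalog` through `callJMethod`, as the diff below shows. For readers more familiar with the Scala API, a rough equivalent of a few of the wrapped calls is sketched here; `spark` is assumed to be an active `SparkSession`, and `"mytable"` and `"/some/path"` are placeholders.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
val catalog = spark.catalog

catalog.currentDatabase                // name of the current default database
catalog.listDatabases().show()         // databases as a Dataset, rendered tabularly
catalog.listTables("default").show()   // includes temporary views
catalog.listColumns("mytable").show()  // assumes a table named "mytable" exists
catalog.refreshByPath("/some/path")    // drop cached data/metadata under a path prefix
```
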
--- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 9 + R/pkg/R/SQLContext.R | 233 ----------- R/pkg/R/catalog.R | 479 ++++++++++++++++++++++ R/pkg/R/utils.R | 18 + R/pkg/inst/tests/testthat/test_sparkSQL.R | 66 ++- 6 files changed, 569 insertions(+), 237 deletions(-) create mode 100644 R/pkg/R/catalog.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 2ea90f7d3666..00dde64324ae 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -32,6 +32,7 @@ Collate: 'pairRDD.R' 'DataFrame.R' 'SQLContext.R' + 'catalog.R' 'WindowSpec.R' 'backend.R' 'broadcast.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 8be7875ad2d5..c02046c94bf4 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -358,9 +358,14 @@ export("as.DataFrame", "clearCache", "createDataFrame", "createExternalTable", + "currentDatabase", "dropTempTable", "dropTempView", "jsonFile", + "listColumns", + "listDatabases", + "listFunctions", + "listTables", "loadDF", "parquetFile", "read.df", @@ -370,7 +375,11 @@ export("as.DataFrame", "read.parquet", "read.stream", "read.text", + "recoverPartitions", + "refreshByPath", + "refreshTable", "setCheckpointDir", + "setCurrentDatabase", "spark.lapply", "spark.addFile", "spark.getSparkFilesRootDirectory", diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index b75fb0159d50..a1edef7608fa 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -569,200 +569,6 @@ tableToDF <- function(tableName) { dataFrame(sdf) } -#' Tables -#' -#' Returns a SparkDataFrame containing names of tables in the given database. -#' -#' @param databaseName name of the database -#' @return a SparkDataFrame -#' @rdname tables -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' tables("hive") -#' } -#' @name tables -#' @method tables default -#' @note tables since 1.4.0 -tables.default <- function(databaseName = NULL) { - sparkSession <- getSparkSession() - jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName) - dataFrame(jdf) -} - -tables <- function(x, ...) { - dispatchFunc("tables(databaseName = NULL)", x, ...) -} - -#' Table Names -#' -#' Returns the names of tables in the given database as an array. -#' -#' @param databaseName name of the database -#' @return a list of table names -#' @rdname tableNames -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' tableNames("hive") -#' } -#' @name tableNames -#' @method tableNames default -#' @note tableNames since 1.4.0 -tableNames.default <- function(databaseName = NULL) { - sparkSession <- getSparkSession() - callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "getTableNames", - sparkSession, - databaseName) -} - -tableNames <- function(x, ...) { - dispatchFunc("tableNames(databaseName = NULL)", x, ...) -} - -#' Cache Table -#' -#' Caches the specified table in-memory. -#' -#' @param tableName The name of the table being cached -#' @return SparkDataFrame -#' @rdname cacheTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' createOrReplaceTempView(df, "table") -#' cacheTable("table") -#' } -#' @name cacheTable -#' @method cacheTable default -#' @note cacheTable since 1.4.0 -cacheTable.default <- function(tableName) { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "cacheTable", tableName)) -} - -cacheTable <- function(x, ...) { - dispatchFunc("cacheTable(tableName)", x, ...) 
-} - -#' Uncache Table -#' -#' Removes the specified table from the in-memory cache. -#' -#' @param tableName The name of the table being uncached -#' @return SparkDataFrame -#' @rdname uncacheTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' createOrReplaceTempView(df, "table") -#' uncacheTable("table") -#' } -#' @name uncacheTable -#' @method uncacheTable default -#' @note uncacheTable since 1.4.0 -uncacheTable.default <- function(tableName) { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "uncacheTable", tableName)) -} - -uncacheTable <- function(x, ...) { - dispatchFunc("uncacheTable(tableName)", x, ...) -} - -#' Clear Cache -#' -#' Removes all cached tables from the in-memory cache. -#' -#' @rdname clearCache -#' @export -#' @examples -#' \dontrun{ -#' clearCache() -#' } -#' @name clearCache -#' @method clearCache default -#' @note clearCache since 1.4.0 -clearCache.default <- function() { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "clearCache")) -} - -clearCache <- function() { - dispatchFunc("clearCache()") -} - -#' (Deprecated) Drop Temporary Table -#' -#' Drops the temporary table with the given table name in the catalog. -#' If the table has been cached/persisted before, it's also unpersisted. -#' -#' @param tableName The name of the SparkSQL table to be dropped. -#' @seealso \link{dropTempView} -#' @rdname dropTempTable-deprecated -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempTable("table") -#' } -#' @name dropTempTable -#' @method dropTempTable default -#' @note dropTempTable since 1.4.0 -dropTempTable.default <- function(tableName) { - if (class(tableName) != "character") { - stop("tableName must be a string.") - } - dropTempView(tableName) -} - -dropTempTable <- function(x, ...) { - .Deprecated("dropTempView") - dispatchFunc("dropTempView(viewName)", x, ...) -} - -#' Drops the temporary view with the given view name in the catalog. -#' -#' Drops the temporary view with the given view name in the catalog. -#' If the view has been cached before, then it will also be uncached. -#' -#' @param viewName the name of the view to be dropped. -#' @return TRUE if the view is dropped successfully, FALSE otherwise. -#' @rdname dropTempView -#' @name dropTempView -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempView("table") -#' } -#' @note since 2.0.0 - -dropTempView <- function(viewName) { - sparkSession <- getSparkSession() - if (class(viewName) != "character") { - stop("viewName must be a string.") - } - catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "dropTempView", viewName) -} - #' Load a SparkDataFrame #' #' Returns the dataset in a data source as a SparkDataFrame @@ -841,45 +647,6 @@ loadDF <- function(x = NULL, ...) { dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } -#' Create an external table -#' -#' Creates an external table based on the dataset in a data source, -#' Returns a SparkDataFrame associated with the external table. -#' -#' The data source is specified by the \code{source} and a set of options(...). 
-#' If \code{source} is not specified, the default data source configured by -#' "spark.sql.sources.default" will be used. -#' -#' @param tableName a name of the table. -#' @param path the path of files to load. -#' @param source the name of external data source. -#' @param ... additional argument(s) passed to the method. -#' @return A SparkDataFrame. -#' @rdname createExternalTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' df <- createExternalTable("myjson", path="path/to/json", source="json") -#' } -#' @name createExternalTable -#' @method createExternalTable default -#' @note createExternalTable since 1.4.0 -createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { - sparkSession <- getSparkSession() - options <- varargsToStrEnv(...) - if (!is.null(path)) { - options[["path"]] <- path - } - catalog <- callJMethod(sparkSession, "catalog") - sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) - dataFrame(sdf) -} - -createExternalTable <- function(x, ...) { - dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) -} - #' Create a SparkDataFrame representing the database table accessible via JDBC URL #' #' Additional JDBC database connection properties can be set (...) diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R new file mode 100644 index 000000000000..07a89f763cde --- /dev/null +++ b/R/pkg/R/catalog.R @@ -0,0 +1,479 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# catalog.R: SparkSession catalog functions + +#' Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns a SparkDataFrame associated with the external table. +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param tableName a name of the table. +#' @param path the path of files to load. +#' @param source the name of external data source. +#' @param schema the schema of the data for certain data source. +#' @param ... additional argument(s) passed to the method. +#' @return A SparkDataFrame. +#' @rdname createExternalTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createExternalTable("myjson", path="path/to/json", source="json", schema) +#' } +#' @name createExternalTable +#' @method createExternalTable default +#' @note createExternalTable since 1.4.0 +createExternalTable.default <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { + sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) 
+ if (!is.null(path)) { + options[["path"]] <- path + } + catalog <- callJMethod(sparkSession, "catalog") + if (is.null(schema)) { + sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) + } else { + sdf <- callJMethod(catalog, "createExternalTable", tableName, source, schema$jobj, options) + } + dataFrame(sdf) +} + +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) +} + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param tableName The name of the table being cached +#' @return SparkDataFrame +#' @rdname cacheTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' cacheTable("table") +#' } +#' @name cacheTable +#' @method cacheTable default +#' @note cacheTable since 1.4.0 +cacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "cacheTable", tableName)) +} + +cacheTable <- function(x, ...) { + dispatchFunc("cacheTable(tableName)", x, ...) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param tableName The name of the table being uncached +#' @return SparkDataFrame +#' @rdname uncacheTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' uncacheTable("table") +#' } +#' @name uncacheTable +#' @method uncacheTable default +#' @note uncacheTable since 1.4.0 +uncacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "uncacheTable", tableName)) +} + +uncacheTable <- function(x, ...) { + dispatchFunc("uncacheTable(tableName)", x, ...) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @rdname clearCache +#' @export +#' @examples +#' \dontrun{ +#' clearCache() +#' } +#' @name clearCache +#' @method clearCache default +#' @note clearCache since 1.4.0 +clearCache.default <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(callJMethod(catalog, "clearCache")) +} + +clearCache <- function() { + dispatchFunc("clearCache()") +} + +#' (Deprecated) Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param tableName The name of the SparkSQL table to be dropped. +#' @seealso \link{dropTempView} +#' @rdname dropTempTable-deprecated +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempTable("table") +#' } +#' @name dropTempTable +#' @method dropTempTable default +#' @note dropTempTable since 1.4.0 +dropTempTable.default <- function(tableName) { + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + dropTempView(tableName) +} + +dropTempTable <- function(x, ...) { + .Deprecated("dropTempView") + dispatchFunc("dropTempView(viewName)", x, ...) +} + +#' Drops the temporary view with the given view name in the catalog. +#' +#' Drops the temporary view with the given view name in the catalog. 
+#' If the view has been cached before, then it will also be uncached. +#' +#' @param viewName the name of the view to be dropped. +#' @return TRUE if the view is dropped successfully, FALSE otherwise. +#' @rdname dropTempView +#' @name dropTempView +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempView("table") +#' } +#' @note since 2.0.0 +dropTempView <- function(viewName) { + sparkSession <- getSparkSession() + if (class(viewName) != "character") { + stop("viewName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "dropTempView", viewName) +} + +#' Tables +#' +#' Returns a SparkDataFrame containing names of tables in the given database. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame +#' @rdname tables +#' @seealso \link{listTables} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' tables("hive") +#' } +#' @name tables +#' @method tables default +#' @note tables since 1.4.0 +tables.default <- function(databaseName = NULL) { + # rename column to match previous output schema + withColumnRenamed(listTables(databaseName), "name", "tableName") +} + +tables <- function(x, ...) { + dispatchFunc("tables(databaseName = NULL)", x, ...) +} + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param databaseName (optional) name of the database +#' @return a list of table names +#' @rdname tableNames +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' tableNames("hive") +#' } +#' @name tableNames +#' @method tableNames default +#' @note tableNames since 1.4.0 +tableNames.default <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getTableNames", + sparkSession, + databaseName) +} + +tableNames <- function(x, ...) { + dispatchFunc("tableNames(databaseName = NULL)", x, ...) +} + +#' Returns the current default database +#' +#' Returns the current default database. +#' +#' @return name of the current default database. +#' @rdname currentDatabase +#' @name currentDatabase +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' currentDatabase() +#' } +#' @note since 2.2.0 +currentDatabase <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "currentDatabase") +} + +#' Sets the current default database +#' +#' Sets the current default database. +#' +#' @param databaseName name of the database +#' @rdname setCurrentDatabase +#' @name setCurrentDatabase +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' setCurrentDatabase("default") +#' } +#' @note since 2.2.0 +setCurrentDatabase <- function(databaseName) { + sparkSession <- getSparkSession() + if (class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "setCurrentDatabase", databaseName)) +} + +#' Returns a list of databases available +#' +#' Returns a list of databases available. +#' +#' @return a SparkDataFrame of the list of databases. 
+#' @rdname listDatabases +#' @name listDatabases +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listDatabases() +#' } +#' @note since 2.2.0 +listDatabases <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + dataFrame(callJMethod(callJMethod(catalog, "listDatabases"), "toDF")) +} + +#' Returns a list of tables in the specified database +#' +#' Returns a list of tables in the specified database. +#' This includes all temporary tables. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of tables. +#' @rdname listTables +#' @name listTables +#' @seealso \link{tables} +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listTables() +#' listTables("default") +#' } +#' @note since 2.2.0 +listTables <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + callJMethod(catalog, "listTables") + } else { + handledCallJMethod(catalog, "listTables", databaseName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Returns a list of columns for the given table in the specified database +#' +#' Returns a list of columns for the given table in the specified database. +#' +#' @param tableName a name of the table. +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of column descriptions. +#' @rdname listColumns +#' @name listColumns +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listColumns("mytable") +#' } +#' @note since 2.2.0 +listColumns <- function(tableName, databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + handledCallJMethod(catalog, "listColumns", tableName) + } else { + handledCallJMethod(catalog, "listColumns", databaseName, tableName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Returns a list of functions registered in the specified database +#' +#' Returns a list of functions registered in the specified database. +#' This includes all temporary functions. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of function descriptions. +#' @rdname listFunctions +#' @name listFunctions +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listFunctions() +#' } +#' @note since 2.2.0 +listFunctions <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + callJMethod(catalog, "listFunctions") + } else { + handledCallJMethod(catalog, "listFunctions", databaseName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Recover all the partitions in the directory of a table and update the catalog +#' +#' Recover all the partitions in the directory of a table and update the catalog. The name should +#' reference a partitioned table, and not a temporary view. +#' +#' @param tableName a name of the table. 
+#' @rdname recoverPartitions +#' @name recoverPartitions +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' recoverPartitions("myTable") +#' } +#' @note since 2.2.0 +recoverPartitions <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "recoverPartitions", tableName)) +} + +#' Invalidate and refresh all the cached metadata of the given table +#' +#' Invalidate and refresh all the cached metadata of the given table. For performance reasons, +#' Spark SQL or the external data source library it uses might cache certain metadata about a +#' table, such as the location of blocks. When those change outside of Spark SQL, users should +#' call this function to invalidate the cache. +#' +#' If this table is cached as an InMemoryRelation, drop the original cached version and make the +#' new version cached lazily. +#' +#' @param tableName a name of the table. +#' @rdname refreshTable +#' @name refreshTable +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' refreshTable("myTable") +#' } +#' @note since 2.2.0 +refreshTable <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "refreshTable", tableName)) +} + +#' Invalidate and refresh all the cached data and metadata for SparkDataFrame containing path +#' +#' Invalidate and refresh all the cached data (and the associated metadata) for any SparkDataFrame +#' that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate +#' everything that is cached. +#' +#' @param path the path of the data source. +#' @rdname refreshByPath +#' @name refreshByPath +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' refreshByPath("/path") +#' } +#' @note since 2.2.0 +refreshByPath <- function(path) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "refreshByPath", path)) +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 810de9917e0b..fbc89e98847b 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -846,6 +846,24 @@ captureJVMException <- function(e, method) { # Extract the first message of JVM exception. first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] stop(paste0(rmsg, "analysis error - ", first), call. = FALSE) + } else + if (any(grep("org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "no such database - ", first), call. = FALSE) + } else + if (any(grep("org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "no such table - ", first), call. = FALSE) } else { stop(stacktrace, call. 
= FALSE) } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 5acf8719d120..ad06711a79a7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -645,16 +645,20 @@ test_that("test tableNames and tables", { df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") expect_equal(length(tableNames()), 1) - tables <- tables() + expect_equal(length(tableNames("default")), 1) + tables <- listTables() expect_equal(count(tables), 1) + expect_equal(count(tables()), count(tables)) + expect_true("tableName" %in% colnames(tables())) + expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) - tables <- tables() + tables <- listTables() expect_equal(count(tables), 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) - tables <- tables() + tables <- listTables() expect_equal(count(tables), 0) }) @@ -686,6 +690,9 @@ test_that("test cache, uncache and clearCache", { uncacheTable("table1") clearCache() expect_true(dropTempView("table1")) + + expect_error(uncacheTable("foo"), + "Error in uncacheTable : no such table - Table or view 'foo' not found in database 'default'") }) test_that("insertInto() on a registered table", { @@ -2821,7 +2828,7 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { # more tests for SPARK-16538 createOrReplaceTempView(df, "table") - SparkR::tables() + SparkR::listTables() SparkR::sql("SELECT 1") suppressWarnings(SparkR::sql(sqlContext, "SELECT * FROM table")) suppressWarnings(SparkR::dropTempTable(sqlContext, "table")) @@ -2977,6 +2984,57 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) }) +test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", { + expect_equal(currentDatabase(), "default") + expect_error(setCurrentDatabase("default"), NA) + expect_error(setCurrentDatabase("foo"), + "Error in setCurrentDatabase : analysis error - Database 'foo' does not exist") + dbs <- collect(listDatabases()) + expect_equal(names(dbs), c("name", "description", "locationUri")) + expect_equal(dbs[[1]], "default") +}) + +test_that("catalog APIs, listTables, listColumns, listFunctions", { + tb <- listTables() + count <- count(tables()) + expect_equal(nrow(tb), count) + expect_equal(colnames(tb), c("name", "database", "description", "tableType", "isTemporary")) + + createOrReplaceTempView(as.DataFrame(cars), "cars") + + tb <- listTables() + expect_equal(nrow(tb), count + 1) + tbs <- collect(tb) + expect_true(nrow(tbs[tbs$name == "cars", ]) > 0) + expect_error(listTables("bar"), + "Error in listTables : no such database - Database 'bar' not found") + + c <- listColumns("cars") + expect_equal(nrow(c), 2) + expect_equal(colnames(c), + c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) + expect_equal(collect(c)[[1]][[1]], "speed") + expect_error(listColumns("foo", "default"), + "Error in listColumns : analysis error - Table 'foo' does not exist in database 'default'") + + f <- listFunctions() + expect_true(nrow(f) >= 200) # 250 + expect_equal(colnames(f), + c("name", "database", "description", "className", "isTemporary")) + expect_equal(take(orderBy(f, "className"), 1)$className, + "org.apache.spark.sql.catalyst.expressions.Abs") + expect_error(listFunctions("foo_db"), + "Error in listFunctions : analysis error - Database 'foo_db' 
does not exist") + + # recoverPartitions does not work with tempory view + expect_error(recoverPartitions("cars"), + "no such table - Table or view 'cars' not found in database 'default'") + expect_error(refreshTable("cars"), NA) + expect_error(refreshByPath("/"), NA) + + dropTempView("cars") +}) + compare_list <- function(list1, list2) { # get testthat to show the diff by first making the 2 lists equal in length expect_equal(length(list1), length(list2)) From 2a903a1eec46e3bd58af0fcbc57e76752d9c18b3 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 3 Apr 2017 10:56:54 +0200 Subject: [PATCH 190/512] [SPARK-19985][ML] Fixed copy method for some ML Models ## What changes were proposed in this pull request? Some ML Models were using `defaultCopy` which expects a default constructor, and others were not setting the parent estimator. This change fixes these by creating a new instance of the model and explicitly setting values and parent. ## How was this patch tested? Added `MLTestingUtils.checkCopy` to the offending models to tests to verify the copy is made and parent is set. Author: Bryan Cutler Closes #17326 from BryanCutler/ml-model-copy-error-SPARK-19985. --- .../MultilayerPerceptronClassifier.scala | 3 ++- .../ml/feature/BucketedRandomProjectionLSH.scala | 5 ++++- .../org/apache/spark/ml/feature/MinHashLSH.scala | 5 ++++- .../scala/org/apache/spark/ml/feature/RFormula.scala | 6 ++++-- .../MultilayerPerceptronClassifierSuite.scala | 1 + .../ml/feature/BucketedRandomProjectionLSHSuite.scala | 6 ++++-- .../org/apache/spark/ml/feature/MinHashLSHSuite.scala | 11 ++++++++++- .../org/apache/spark/ml/feature/RFormulaSuite.scala | 1 + 8 files changed, 30 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 95c1337ed560..ec39f964e213 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { - copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) + val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent) + copyValues(copied, extra) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index cbac16345a29..36a46ca6ff4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml]( } @Since("2.1.0") - override def copy(extra: ParamMap): this.type = defaultCopy(extra) + override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = { + val copied = new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent) + copyValues(copied, extra) + } @Since("2.1.0") override def write: MLWriter = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 620e1fbb09ff..145422a05919 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -86,7 +86,10 @@ class MinHashLSHModel private[ml]( } @Since("2.1.0") - override def copy(extra: ParamMap): this.type = defaultCopy(extra) + override def copy(extra: ParamMap): MinHashLSHModel = { + val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent) + copyValues(copied, extra) + } @Since("2.1.0") override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 389898666eb8..5a3e2929f5f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -268,8 +268,10 @@ class RFormulaModel private[feature]( } @Since("1.5.0") - override def copy(extra: ParamMap): RFormulaModel = copyValues( - new RFormulaModel(uid, resolvedFormula, pipelineModel)) + override def copy(extra: ParamMap): RFormulaModel = { + val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent) + copyValues(copied, extra) + } @Since("2.0.0") override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 41684d92be33..7700099caac3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -74,6 +74,7 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) + MLTestingUtils.checkCopy(model) val result = model.transform(dataset) val predictionAndLabels = result.select("prediction", "label").collect() predictionAndLabels.foreach { case Row(p: Double, l: Double) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index 91eac9e73331..cc81da5c66e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -89,10 +89,12 @@ class BucketedRandomProjectionLSHSuite .setOutputCol("values") .setBucketLength(1.0) .setSeed(12345) - val unitVectors = brp.fit(dataset).randUnitVectors + val brpModel = brp.fit(dataset) + val unitVectors = brpModel.randUnitVectors unitVectors.foreach { v: Vector => assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) } + MLTestingUtils.checkCopy(brpModel) } test("BucketedRandomProjectionLSH: test of LSH property") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index a2f009310fd7..0ddf097a6eb2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } + test("Model copy and uid checks") { + val mh = new MinHashLSH() + .setInputCol("keys") + .setOutputCol("values") + val model = mh.fit(dataset) + assert(mh.uid === model.uid) + MLTestingUtils.checkCopy(model) + } + test("hashFunction") { val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0))) val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index c664460d7d8b..5cfd59e6b88a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val formula = new RFormula().setFormula("id ~ v1 + v2") val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) + MLTestingUtils.checkCopy(model) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) val expected = Seq( From cff11fd20e869d14106d2d0f17df67161c44d476 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 3 Apr 2017 10:07:41 +0100 Subject: [PATCH 191/512] [SPARK-20166][SQL] Use XXX for ISO 8601 timezone instead of ZZ (FastDateFormat specific) in CSV/JSON timeformat options ## What changes were proposed in this pull request? This PR proposes to use `XXX` format instead of `ZZ`. `ZZ` seems a `FastDateFormat` specific. `ZZ` supports "ISO 8601 extended format time zones" but it seems `FastDateFormat` specific option. I misunderstood this is compatible format with `SimpleDateFormat` when this change is introduced. Please see [SimpleDateFormat documentation]( https://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html#iso8601timezone) and [FastDateFormat documentation](https://commons.apache.org/proper/commons-lang/apidocs/org/apache/commons/lang3/time/FastDateFormat.html). 
It seems we better replace `ZZ` to `XXX` because they look using the same strategy - [FastDateParser.java#L930](https://github.com/apache/commons-lang/blob/8767cd4f1a6af07093c1e6c422dae8e574be7e5e/src/main/java/org/apache/commons/lang3/time/FastDateParser.java#L930), [FastDateParser.java#L932-L951 ](https://github.com/apache/commons-lang/blob/8767cd4f1a6af07093c1e6c422dae8e574be7e5e/src/main/java/org/apache/commons/lang3/time/FastDateParser.java#L932-L951) and [FastDateParser.java#L596-L601](https://github.com/apache/commons-lang/blob/8767cd4f1a6af07093c1e6c422dae8e574be7e5e/src/main/java/org/apache/commons/lang3/time/FastDateParser.java#L596-L601). I also checked the codes and manually debugged it for sure. It seems both cases use the same pattern `( Z|(?:[+-]\\d{2}(?::)\\d{2}))`. _Note that this should be rather a fix about documentation and not the behaviour change because `ZZ` seems invalid date format in `SimpleDateFormat` as documented in `DataFrameReader` and etc, and both `ZZ` and `XXX` look identically working with `FastDateFormat`_ Current documentation is as below: ``` *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • ``` ## How was this patch tested? Existing tests should cover this. Also, manually tested as below (BTW, I don't think these are worth being added as tests within Spark): **Parse** ```scala scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000-11:00") res4: java.util.Date = Tue Mar 21 20:00:00 KST 2017 scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000Z") res10: java.util.Date = Tue Mar 21 09:00:00 KST 2017 scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZZ").parse("2017-03-21T00:00:00.000-11:00") java.text.ParseException: Unparseable date: "2017-03-21T00:00:00.000-11:00" at java.text.DateFormat.parse(DateFormat.java:366) ... 48 elided scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZZ").parse("2017-03-21T00:00:00.000Z") java.text.ParseException: Unparseable date: "2017-03-21T00:00:00.000Z" at java.text.DateFormat.parse(DateFormat.java:366) ... 48 elided ``` ```scala scala> org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000-11:00") res7: java.util.Date = Tue Mar 21 20:00:00 KST 2017 scala> org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000Z") res1: java.util.Date = Tue Mar 21 09:00:00 KST 2017 scala> org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ").parse("2017-03-21T00:00:00.000-11:00") res8: java.util.Date = Tue Mar 21 20:00:00 KST 2017 scala> org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ").parse("2017-03-21T00:00:00.000Z") res2: java.util.Date = Tue Mar 21 09:00:00 KST 2017 ``` **Format** ```scala scala> new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").format(new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX").parse("2017-03-21T00:00:00.000-11:00")) res6: String = 2017-03-21T20:00:00.000+09:00 ``` ```scala scala> val fd = org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ") fd: org.apache.commons.lang3.time.FastDateFormat = FastDateFormat[yyyy-MM-dd'T'HH:mm:ss.SSSZZ,ko_KR,Asia/Seoul] scala> fd.format(fd.parse("2017-03-21T00:00:00.000-11:00")) res1: String = 2017-03-21T20:00:00.000+09:00 scala> val fd = org.apache.commons.lang3.time.FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX") fd: org.apache.commons.lang3.time.FastDateFormat = FastDateFormat[yyyy-MM-dd'T'HH:mm:ss.SSSXXX,ko_KR,Asia/Seoul] scala> fd.format(fd.parse("2017-03-21T00:00:00.000-11:00")) res2: String = 2017-03-21T20:00:00.000+09:00 ``` Author: hyukjinkwon Closes #17489 from HyukjinKwon/SPARK-20166. 
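Editor's note (not part of the patch): a minimal sketch of how the new default surfaces to users of the DataFrame reader. The column name `t`, the sample rows, and the local-mode session are illustrative assumptions; the point is that an explicit `timestampFormat` ending in `XXX` parses ISO 8601 offsets such as `-11:00` and `Z`, matching the documented default after this change.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StructType, TimestampType}

object TimestampFormatCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TimestampFormatCheck")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // ISO 8601 offsets ("-11:00") and the literal "Z" both parse with the XXX pattern.
    val schema = new StructType().add("t", TimestampType)
    val input = Seq(
      """{"t": "2017-03-21T00:00:00.000-11:00"}""",
      """{"t": "2017-03-21T00:00:00.000Z"}""").toDS()

    val parsed = spark.read
      .schema(schema)
      .option("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")
      .json(input)

    parsed.show(truncate = false)
    spark.stop()
  }
}
```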
--- python/pyspark/sql/readwriter.py | 8 ++++---- python/pyspark/sql/streaming.py | 4 ++-- .../org/apache/spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../main/scala/org/apache/spark/sql/DataFrameReader.scala | 4 ++-- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../spark/sql/execution/datasources/csv/CSVOptions.scala | 2 +- .../org/apache/spark/sql/streaming/DataStreamReader.scala | 4 ++-- .../spark/sql/execution/datasources/csv/CSVSuite.scala | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 5e732b4bec8f..d912f395dafc 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -223,7 +223,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. @@ -363,7 +363,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is set, it uses the default value, ``20480``. :param maxCharsPerColumn: defines the maximum number of characters allowed for any given @@ -653,7 +653,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -745,7 +745,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from values being written should be skipped. If None is set, it uses the default value, ``true``. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 27d6725615a4..3b604963415f 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -457,7 +457,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. 
+ default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. @@ -581,7 +581,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is set, it uses the default value, ``20480``. :param maxCharsPerColumn: defines the maximum number of characters allowed for any given diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index c22b1ade4e64..23ba5ed4d50d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -79,7 +79,7 @@ private[sql] class JSONOptions( val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6c238618f2af..2b8537c3d4a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -320,7 +320,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
    • - *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    • `wholeFile` (default `false`): parse one record, which may span multiple lines, @@ -502,7 +502,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
    • - *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    • `maxColumns` (default `20480`): defines a hard limit of how many columns diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e973d0bc6d09..338a6e1314d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -477,7 +477,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
    • - *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    @@ -583,7 +583,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • *
  • `ignoreLeadingWhiteSpace` (default `true`): a flag indicating whether or not leading diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index e7b79e0cbfd1..4994b8dc8052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -126,7 +126,7 @@ class CSVOptions( val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 997ca286597d..c3a9cfc08517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -201,7 +201,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, @@ -252,7 +252,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d70c47f4e237..352dba79a4c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -766,7 +766,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .load(iso8601timestampsPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX", Locale.US) val expectedTimestamps = timestamps.collect().map { r => // This should be ISO8601 formatted string. Row(iso8501.format(r.toSeq.head)) From 364b0db75308ddd346b4ab1e032680e8eb4c1753 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 3 Apr 2017 10:09:11 +0100 Subject: [PATCH 192/512] [MINOR][DOCS] Replace non-breaking space to normal spaces that breaks rendering markdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What changes were proposed in this pull request? It seems there are several non-breaking spaces were inserted into several `.md`s and they look breaking rendering markdown files. These are different. For example, this can be checked via `python` as below: ```python >>> " " '\xc2\xa0' >>> " " ' ' ``` _Note that it seems this PR description automatically replaces non-breaking spaces into normal spaces. Please open a `vi` and copy and paste it into `python` to verify this (do not copy the characters here)._ I checked the output below in Sapari and Chrome on Mac OS and, Internal Explorer on Windows 10. **Before** ![2017-04-03 12 37 17](https://cloud.githubusercontent.com/assets/6477701/24594655/50aaba02-186a-11e7-80bb-d34b17a3398a.png) ![2017-04-03 12 36 57](https://cloud.githubusercontent.com/assets/6477701/24594654/50a855e6-186a-11e7-94e2-661e56544b0f.png) **After** ![2017-04-03 12 36 46](https://cloud.githubusercontent.com/assets/6477701/24594657/53c2545c-186a-11e7-9a73-00529afbfd75.png) ![2017-04-03 12 36 31](https://cloud.githubusercontent.com/assets/6477701/24594658/53c286c0-186a-11e7-99c9-e66b1f510fe7.png) ## How was this patch tested? Manually checking. These instances were found via ``` grep --include=*.scala --include=*.python --include=*.java --include=*.r --include=*.R --include=*.md --include=*.r -r -I " " . ``` in Mac OS. It seems there are several instances more as below: ``` ./docs/sql-programming-guide.md: │   ├── ... ./docs/sql-programming-guide.md: │   │ ./docs/sql-programming-guide.md: │   ├── country=US ./docs/sql-programming-guide.md: │   │   └── data.parquet ./docs/sql-programming-guide.md: │   ├── country=CN ./docs/sql-programming-guide.md: │   │   └── data.parquet ./docs/sql-programming-guide.md: │   └── ... ./docs/sql-programming-guide.md:    ├── ... ./docs/sql-programming-guide.md:    │ ./docs/sql-programming-guide.md:    ├── country=US ./docs/sql-programming-guide.md:    │   └── data.parquet ./docs/sql-programming-guide.md:    ├── country=CN ./docs/sql-programming-guide.md:    │   └── data.parquet ./docs/sql-programming-guide.md:    └── ... ./sql/core/src/test/README.md:│   ├── *.avdl # Testing Avro IDL(s) ./sql/core/src/test/README.md:│   └── *.avpr # !! NO TOUCH !! 
Protocol files generated from Avro IDL(s) ./sql/core/src/test/README.md:│   ├── gen-avro.sh # Script used to generate Java code for Avro ./sql/core/src/test/README.md:│   └── gen-thrift.sh # Script used to generate Java code for Thrift ``` These seems generated via `tree` command which inserts non-breaking spaces. They do not look causing any problem for rendering within code blocks and I did not fix it to reduce the overhead to manually replace it when it is overwritten via `tree` command in the future. Author: hyukjinkwon Closes #17517 from HyukjinKwon/non-breaking-space. --- README.md | 2 +- docs/building-spark.md | 2 +- docs/monitoring.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d0eca1ddea28..1e521a7e7b17 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ building for particular Hive and Hive Thriftserver distributions. Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. -## Contributing +## Contributing Please review the [Contribution to Spark guide](http://spark.apache.org/contributing.html) for information on how to get started contributing to the project. diff --git a/docs/building-spark.md b/docs/building-spark.md index 8353b7a520b8..e99b70f7a8b4 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -154,7 +154,7 @@ Developers who compile Spark frequently may want to speed up compilation; e.g., developers who build with SBT). For more information about how to do this, refer to the [Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times). -## Encrypted Filesystems +## Encrypted Filesystems When building on an encrypted filesystem (if your home directory is encrypted, for example), then the Spark build might fail with a "Filename too long" error. As a workaround, add the following in the configuration args of the `scala-maven-plugin` in the project `pom.xml`: diff --git a/docs/monitoring.md b/docs/monitoring.md index 80519525af0c..6cbc6660e816 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -257,7 +257,7 @@ In the API, an application is referenced by its application ID, `[app-id]`. When running on YARN, each application may have multiple attempts, but there are attempt IDs only for applications in cluster mode, not applications in client mode. Applications in YARN cluster mode can be identified by their `[attempt-id]`. In the API listed below, when running in YARN cluster mode, -`[app-id]` will actually be `[base-app-id]/[attempt-id]`, where `[base-app-id]` is the YARN application ID. +`[app-id]` will actually be `[base-app-id]/[attempt-id]`, where `[base-app-id]` is the YARN application ID. From fb5869f2cf94217b3e254e2d0820507dc83a25cc Mon Sep 17 00:00:00 2001 From: Denis Bolshakov Date: Mon, 3 Apr 2017 10:16:07 +0100 Subject: [PATCH 193/512] [SPARK-9002][CORE] KryoSerializer initialization does not include 'Array[Int]' [SPARK-9002][CORE] KryoSerializer initialization does not include 'Array[Int]' ## What changes were proposed in this pull request? Array[Int] has been registered in KryoSerializer. The following file has been changed core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala ## How was this patch tested? First, the issue was reproduced by new unit test. Then, the issue was fixed to pass the failed test. Author: Denis Bolshakov Closes #17482 from dbolshak/SPARK-9002. 
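Editor's note (not part of the patch): a minimal standalone check, e.g. in `spark-shell`, mirroring what the new `KryoSerializerSuite` cases assert — with `spark.kryo.registrationRequired=true`, a primitive `Array[Int]` now round-trips without any user-side `registerKryoClasses` call, because the primitive array classes are pre-registered by `KryoSerializer`.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Round-trip a primitive Int array with registration enforcement turned on.
val conf = new SparkConf(false)
  .set("spark.kryo.registrationRequired", "true")

val ser = new KryoSerializer(conf).newInstance()
val bytes = ser.serialize(Array(1, 2, 3))
val restored = ser.deserialize[Array[Int]](bytes)
assert(restored.sameElements(Array(1, 2, 3)))
```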
--- .../org/apache/spark/serializer/KryoSerializer.scala | 7 +++++++ .../apache/spark/serializer/KryoSerializerSuite.scala | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 03815631a604..6fc66e2374bd 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -384,9 +384,16 @@ private[serializer] object KryoSerializer { classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], + classOf[Array[Boolean]], classOf[Array[Byte]], classOf[Array[Short]], + classOf[Array[Int]], classOf[Array[Long]], + classOf[Array[Float]], + classOf[Array[Double]], + classOf[Array[Char]], + classOf[Array[String]], + classOf[Array[Array[String]]], classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a30653bb36fa..7c3922e47fbb 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -76,6 +76,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("basic types") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) @@ -106,6 +109,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("pairs") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) @@ -130,12 +136,16 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("Scala data structures") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check(List[Int]()) check(List[Int](1, 2, 3)) + check(Seq[Int](1, 2, 3)) check(List[String]()) check(List[String]("x", "y", "z")) check(None) From 4d28e8430d11323f08657ca8f3251ca787c45501 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 3 Apr 2017 11:42:33 +0200 Subject: [PATCH 194/512] [SPARK-19969][ML] Imputer doc and example ## What changes were proposed in this pull request? Add docs and examples for spark.ml.feature.Imputer. Currently scala and Java examples are included. Python example will be added after https://github.com/apache/spark/pull/17316 ## How was this patch tested? local doc generation and example execution Author: Yuhao Yang Closes #17324 from hhbyyh/imputerdoc. 
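Editor's note (not part of the patch): the bundled examples use the default `mean` strategy and the default `Double.NaN` missing-value marker. The sketch below — column names and values are illustrative, and it assumes the `setStrategy`/`setMissingValue` setters exposed by `Imputer`'s params — additionally exercises the `median` strategy with a custom sentinel.

```scala
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("ImputerMedianSketch")
  .master("local[*]")
  .getOrCreate()

// Here -1.0 stands in for "missing" instead of the default Double.NaN.
val df = spark.createDataFrame(Seq(
  (1.0, 4.0),
  (2.0, -1.0),
  (-1.0, 6.0),
  (8.0, 10.0)
)).toDF("a", "b")

val imputer = new Imputer()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("out_a", "out_b"))
  .setStrategy("median")
  .setMissingValue(-1.0)

imputer.fit(df).transform(df).show()
spark.stop()
```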
--- docs/ml-features.md | 66 +++++++++++++++++ .../spark/examples/ml/JavaImputerExample.java | 71 +++++++++++++++++++ .../src/main/python/ml/imputer_example.py | 50 +++++++++++++ .../spark/examples/ml/ImputerExample.scala | 56 +++++++++++++++ .../org/apache/spark/ml/feature/Imputer.scala | 2 +- 5 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java create mode 100644 examples/src/main/python/ml/imputer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index dad1c6db18f8..e19fba249fb2 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1284,6 +1284,72 @@ for more details on the API. + +## Imputer + +The `Imputer` transformer completes missing values in a dataset, either using the mean or the +median of the columns in which the missing values are located. The input columns should be of +`DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly +creates incorrect values for columns containing categorical features. + +**Note** all `null` values in the input columns are treated as missing, and so are also imputed. + +**Examples** + +Suppose that we have a DataFrame with the columns `a` and `b`: + +~~~ + a | b +------------|----------- + 1.0 | Double.NaN + 2.0 | Double.NaN + Double.NaN | 3.0 + 4.0 | 4.0 + 5.0 | 5.0 +~~~ + +In this example, Imputer will replace all occurrences of `Double.NaN` (the default for the missing value) +with the mean (the default imputation strategy) computed from the other values in the corresponding columns. +In this example, the surrogate values for columns `a` and `b` are 3.0 and 4.0 respectively. After +transformation, the missing values in the output columns will be replaced by the surrogate value for +the relevant column. + +~~~ + a | b | out_a | out_b +------------|------------|-------|------- + 1.0 | Double.NaN | 1.0 | 4.0 + 2.0 | Double.NaN | 2.0 | 4.0 + Double.NaN | 3.0 | 3.0 | 3.0 + 4.0 | 4.0 | 4.0 | 4.0 + 5.0 | 5.0 | 5.0 | 5.0 +~~~ + +
+<div data-lang="scala" markdown="1">
+ +Refer to the [Imputer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Imputer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ImputerExample.scala %} +</div>
+ +<div data-lang="java" markdown="1">
+ +Refer to the [Imputer Java docs](api/java/org/apache/spark/ml/feature/Imputer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaImputerExample.java %} +</div>
+ +<div data-lang="python" markdown="1">
+ +Refer to the [Imputer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Imputer) +for more details on the API. + +{% include_example python/ml/imputer_example.py %} +</div>
+</div>
    + # Feature Selectors ## VectorSlicer diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java new file mode 100644 index 000000000000..ac40ccd9dbd7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.feature.Imputer; +import org.apache.spark.ml.feature.ImputerModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.*; +// $example off$ + +import static org.apache.spark.sql.types.DataTypes.*; + +/** + * An example demonstrating Imputer. + * Run with: + * bin/run-example ml.JavaImputerExample + */ +public class JavaImputerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaImputerExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1.0, Double.NaN), + RowFactory.create(2.0, Double.NaN), + RowFactory.create(Double.NaN, 3.0), + RowFactory.create(4.0, 4.0), + RowFactory.create(5.0, 5.0) + ); + StructType schema = new StructType(new StructField[]{ + createStructField("a", DoubleType, false), + createStructField("b", DoubleType, false) + }); + Dataset df = spark.createDataFrame(data, schema); + + Imputer imputer = new Imputer() + .setInputCols(new String[]{"a", "b"}) + .setOutputCols(new String[]{"out_a", "out_b"}); + + ImputerModel model = imputer.fit(df); + model.transform(df).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/imputer_example.py b/examples/src/main/python/ml/imputer_example.py new file mode 100644 index 000000000000..b8437f827e56 --- /dev/null +++ b/examples/src/main/python/ml/imputer_example.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.ml.feature import Imputer +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating Imputer. +Run with: + bin/spark-submit examples/src/main/python/ml/imputer_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("ImputerExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (1.0, float("nan")), + (2.0, float("nan")), + (float("nan"), 3.0), + (4.0, 4.0), + (5.0, 5.0) + ], ["a", "b"]) + + imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]) + model = imputer.fit(df) + + model.transform(df).show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala new file mode 100644 index 000000000000..49e98d0c622c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Imputer +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating Imputer. + * Run with: + * bin/run-example ml.ImputerExample + */ +object ImputerExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder + .appName("ImputerExample") + .getOrCreate() + + // $example on$ + val df = spark.createDataFrame(Seq( + (1.0, Double.NaN), + (2.0, Double.NaN), + (Double.NaN, 3.0), + (4.0, 4.0), + (5.0, 5.0) + )).toDF("a", "b") + + val imputer = new Imputer() + .setInputCols(Array("a", "b")) + .setOutputCols(Array("out_a", "out_b")) + + val model = imputer.fit(df) + model.transform(df).show() + // $example off$ + + spark.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index ec4c6ad75ee2..a41bd8e689d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.types._ private[feature] trait ImputerParams extends Params with HasInputCols { /** - * The imputation strategy. + * The imputation strategy. 
Currently only "mean" and "median" are supported. * If "mean", then replace missing values using the mean value of the feature. * If "median", then replace missing values using the approximate median value of the feature. * Default: mean From 4fa1a43af6b5a6abaef7e04cacb2617a2e92d816 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 3 Apr 2017 17:44:39 +0800 Subject: [PATCH 195/512] [SPARK-19641][SQL] JSON schema inference in DROPMALFORMED mode produces incorrect schema for non-array/object JSONs ## What changes were proposed in this pull request? Currently, when we infer the types for vaild JSON strings but object or array, we are producing empty schemas regardless of parse modes as below: ```scala scala> spark.read.option("mode", "DROPMALFORMED").json(Seq("""{"a": 1}""", """"a"""").toDS).printSchema() root ``` ```scala scala> spark.read.option("mode", "FAILFAST").json(Seq("""{"a": 1}""", """"a"""").toDS).printSchema() root ``` This PR proposes to handle parse modes in type inference. After this PR, ```scala scala> spark.read.option("mode", "DROPMALFORMED").json(Seq("""{"a": 1}""", """"a"""").toDS).printSchema() root |-- a: long (nullable = true) ``` ``` scala> spark.read.option("mode", "FAILFAST").json(Seq("""{"a": 1}""", """"a"""").toDS).printSchema() java.lang.RuntimeException: Failed to infer a common schema. Struct types are expected but string was found. ``` This PR is based on https://github.com/NathanHowell/spark/commit/e233fd03346a73b3b447fa4c24f3b12c8b2e53ae and I and NathanHowell talked about this in https://issues.apache.org/jira/browse/SPARK-19641 ## How was this patch tested? Unit tests in `JsonSuite` for both `DROPMALFORMED` and `FAILFAST` modes. Author: hyukjinkwon Closes #17492 from HyukjinKwon/SPARK-19641. --- .../datasources/json/JsonInferSchema.scala | 77 +++++++++++-------- .../datasources/json/JsonSuite.scala | 34 +++++++- 2 files changed, 78 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e15c30b4374b..fb632cf2bb70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions -import org.apache.spark.sql.catalyst.util.PermissiveMode +import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -41,7 +41,7 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - val shouldHandleCorruptRecord = configOptions.parseMode == PermissiveMode + val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord // perform schema inference on each row and merge afterwards @@ -55,20 +55,24 @@ private[sql] object JsonInferSchema { Some(inferField(parser, configOptions)) } } catch { - case _: JsonParseException if shouldHandleCorruptRecord => - Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) - case _: JsonParseException => - None + case e @ (_: 
RuntimeException | _: JsonProcessingException) => parseMode match { + case PermissiveMode => + Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) + case DropMalformedMode => + None + case FailFastMode => + throw e + } } } - }.fold(StructType(Seq()))( - compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord)) + }.fold(StructType(Nil))( + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep - StructType(Seq()) + StructType(Nil) } } @@ -202,19 +206,33 @@ private[sql] object JsonInferSchema { private def withCorruptField( struct: StructType, - columnNameOfCorruptRecords: String): StructType = { - if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { - // If this given struct does not have a column used for corrupt records, - // add this field. - val newFields: Array[StructField] = - StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields - // Note: other code relies on this sorting for correctness, so don't remove it! - java.util.Arrays.sort(newFields, structFieldComparator) - StructType(newFields) - } else { - // Otherwise, just return this struct. + other: DataType, + columnNameOfCorruptRecords: String, + parseMode: ParseMode) = parseMode match { + case PermissiveMode => + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + val newFields: Array[StructField] = + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(newFields, structFieldComparator) + StructType(newFields) + } else { + // Otherwise, just return this struct. + struct + } + + case DropMalformedMode => + // If corrupt record handling is disabled we retain the valid schema and discard the other. struct - } + + case FailFastMode => + // If `other` is not struct type, consider it as malformed one and throws an exception. + throw new RuntimeException("Failed to infer a common schema. Struct types are expected" + + s" but ${other.catalogString} was found.") } /** @@ -222,21 +240,20 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) - // If we see any other data type at the root level, we get records that cannot be - // parsed. So, we use the struct as the data type and add the corrupt field to the schema. 
+ compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct - case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) - case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. case (ty1, ty2) => compatibleType(ty1, ty2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b09cef76d2be..2ab03819964b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1041,7 +1041,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read .option("mode", "FAILFAST") .json(corruptRecords) - .collect() } assert(exceptionOne.getMessage.contains("JsonParseException")) @@ -1082,6 +1081,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(jsonDFTwo.schema === schemaTwo) } + test("SPARK-19641: Additional corrupt records: DROPMALFORMED mode") { + val schema = new StructType().add("dummy", StringType) + // `DROPMALFORMED` mode should skip corrupt records + val jsonDF = spark.read + .option("mode", "DROPMALFORMED") + .json(additionalCorruptRecords) + checkAnswer( + jsonDF, + Row("test")) + assert(jsonDF.schema === schema) + } + test("Corrupt records: PERMISSIVE mode, without designated column for malformed records") { val schema = StructType( StructField("a", StringType, true) :: @@ -1882,6 +1893,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-19641: Handle multi-line corrupt documents (DROPMALFORMED)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + checkAnswer(jsonDF, Seq(Row("test"))) + } + } + test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") { withTempPath { dir => val path = dir.getCanonicalPath @@ -1903,9 +1932,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .option("wholeFile", true) .option("mode", "FAILFAST") .json(path) - .collect() } - assert(exceptionOne.getMessage.contains("Failed to parse a value")) + assert(exceptionOne.getMessage.contains("Failed to infer a common schema")) val exceptionTwo = intercept[SparkException] { spark.read From 703c42c398fefd3f7f60e1c503c4df50251f8dcf Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Mon, 3 Apr 2017 08:48:49 -0700 
Subject: [PATCH 196/512] [SPARK-20194] Add support for partition pruning to in-memory catalog ## What changes were proposed in this pull request? This patch implements `listPartitionsByFilter()` for `InMemoryCatalog` and thus resolves an outstanding TODO causing the `PruneFileSourcePartitions` optimizer rule not to apply when "spark.sql.catalogImplementation" is set to "in-memory" (which is the default). The change is straightforward: it extracts the code for further filtering of the list of partitions returned by the metastore's `getPartitionsByFilter()` out from `HiveExternalCatalog` into `ExternalCatalogUtils` and calls this new function from `InMemoryCatalog` on the whole list of partitions. Now that this method is implemented we can always pass the `CatalogTable` to the `DataSource` in `FindDataSourceTable`, so that the latter is resolved to a relation with a `CatalogFileIndex`, which is what the `PruneFileSourcePartitions` rule matches for. ## How was this patch tested? Ran existing tests and added new test for `listPartitionsByFilter` in `ExternalCatalogSuite`, which is subclassed by both `InMemoryCatalogSuite` and `HiveExternalCatalogSuite`. Author: Adrian Ionescu Closes #17510 from adrian-ionescu/InMemoryCatalog. --- .../catalog/ExternalCatalogUtils.scala | 33 +++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 8 ++-- .../catalog/ExternalCatalogSuite.scala | 41 +++++++++++++++++++ .../datasources/DataSourceStrategy.scala | 5 +-- .../spark/sql/hive/HiveExternalCatalog.scala | 33 +++------------ .../spark/sql/hive/client/HiveShim.scala | 2 +- .../sql/hive/HiveExternalCatalogSuite.scala | 8 ---- 7 files changed, 85 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index a8693dcca539..254eedfe7751 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate} object ExternalCatalogUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't @@ -125,6 +126,38 @@ object ExternalCatalogUtils { } escapePathName(col) + "=" + partitionString } + + def prunePartitionsByFilter( + catalogTable: CatalogTable, + inputPartitions: Seq[CatalogTablePartition], + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + if (predicates.isEmpty) { + inputPartitions + } else { + val partitionSchema = catalogTable.partitionSchema + val partitionColumnNames = catalogTable.partitionColumnNames.toSet + + val nonPartitionPruningPredicates = predicates.filterNot { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + if (nonPartitionPruningPredicates.nonEmpty) { + throw new AnalysisException("Expected only partition pruning predicates: " + + nonPartitionPruningPredicates) + } + + val boundPredicate = + InterpretedPredicate.create(predicates.reduce(And).transform { + case att: AttributeReference => + val index = 
partitionSchema.indexWhere(_.name == att.name) + BoundReference(index, partitionSchema(index).dataType, nullable = true) + }) + + inputPartitions.filter { p => + boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) + } + } + } } object CatalogUtils { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index cdf618aef97c..9ca1c71d1dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.types.StructType @@ -556,9 +556,9 @@ class InMemoryCatalog( table: String, predicates: Seq[Expression], defaultTimeZoneId: String): Seq[CatalogTablePartition] = { - // TODO: Provide an implementation - throw new UnsupportedOperationException( - "listPartitionsByFilter is not implemented for InMemoryCatalog") + val catalogTable = getTable(db, table) + val allPartitions = listPartitions(db, table) + prunePartitionsByFilter(catalogTable, allPartitions, predicates, defaultTimeZoneId) } // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 7820f39d9642..42db4398e507 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.TimeZone import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -28,6 +29,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -436,6 +439,44 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty) } + test("list partitions by filter") { + val tz = TimeZone.getDefault.getID + val catalog = newBasicCatalog() + + def checkAnswer( + table: CatalogTable, filters: Seq[Expression], expected: Set[CatalogTablePartition]) + : Unit = { + + assertResult(expected.map(_.spec)) { + catalog.listPartitionsByFilter(table.database, table.identifier.identifier, filters, tz) + .map(_.spec).toSet + } + } + + val tbl2 = catalog.getTable("db2", "tbl2") + + 
checkAnswer(tbl2, Seq.empty, Set(part1, part2)) + checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 2), Set.empty) + checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) + checkAnswer(tbl2, Seq('a.int === 1 || 'b.string === "x"), Set(part1)) + + intercept[AnalysisException] { + try { + checkAnswer(tbl2, Seq('a.int > 0 && 'col1.int > 0), Set.empty) + } catch { + // HiveExternalCatalog may be the first one to notice and throw an exception, which will + // then be caught and converted to a RuntimeException with a descriptive message. + case ex: RuntimeException if ex.getMessage.contains("MetaException") => + throw new AnalysisException(ex.getMessage) + } + } + } + test("drop partitions") { val catalog = newBasicCatalog() assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index bddf5af23e06..c350d8bcbae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -217,8 +217,6 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] val table = r.tableMeta val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) val cache = sparkSession.sessionState.catalog.tableRelationCache - val withHiveSupport = - sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive" val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { @@ -233,8 +231,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] bucketSpec = table.bucketSpec, className = table.provider.get, options = table.storage.properties ++ pathOption, - // TODO: improve `InMemoryCatalog` and remove this limitation. 
- catalogTable = if (withHiveSupport) Some(table) else None) + catalogTable = Some(table)) LogicalRelation( dataSource.resolveRelation(checkFilesExist = false), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 33b21be37203..f0e35dff57f7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -1039,37 +1039,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient { val rawTable = getRawTable(db, table) val catalogTable = restoreTableMetadata(rawTable) - val partitionColumnNames = catalogTable.partitionColumnNames.toSet - val nonPartitionPruningPredicates = predicates.filterNot { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - if (nonPartitionPruningPredicates.nonEmpty) { - sys.error("Expected only partition pruning predicates: " + - predicates.reduceLeft(And)) - } + val partColNameMap = buildLowerCasePartColNameMap(catalogTable) - val partitionSchema = catalogTable.partitionSchema - val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table)) - - if (predicates.nonEmpty) { - val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part => + val clientPrunedPartitions = + client.getPartitionsByFilter(rawTable, predicates).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) } - val boundPredicate = - InterpretedPredicate.create(predicates.reduce(And).transform { - case att: AttributeReference => - val index = partitionSchema.indexWhere(_.name == att.name) - BoundReference(index, partitionSchema(index).dataType, nullable = true) - }) - clientPrunedPartitions.filter { p => - boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) - } - } else { - client.getPartitions(catalogTable).map { part => - part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) - } - } + prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId) } // -------------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index d55c41e5c9f2..2e35f3983948 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -584,7 +584,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { */ def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
- val varcharKeys = table.getPartitionKeys.asScala + lazy val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 4349f1aa23be..bd54c043c6ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.types.StructType @@ -50,13 +49,6 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { import utils._ - test("list partitions by filter") { - val catalog = newBasicCatalog() - val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT") - assert(selectedPartitions.length == 1) - assert(selectedPartitions.head.spec == part1.spec) - } - test("SPARK-18647: do not put provider in table properties for Hive serde table") { val catalog = newBasicCatalog() val hiveTable = CatalogTable( From 58c9e6e77ae26345291dd9fce2c57aadcc36f66c Mon Sep 17 00:00:00 2001 From: samelamin Date: Mon, 3 Apr 2017 17:16:31 -0700 Subject: [PATCH 197/512] [SPARK-20145] Fix range case insensitive bug in SQL ## What changes were proposed in this pull request? Range in SQL should be case insensitive ## How was this patch tested? unit test Author: samelamin Author: samelamin Closes #17487 from samelamin/SPARK-20145. 
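For illustration, a minimal sketch of the behavior this fixes (assuming a local `SparkSession` is in scope as `spark`; the table-valued function name is now lower-cased before lookup):

```scala
// Any casing of the built-in table-valued function `range` now resolves
// to the same logical Range plan.
spark.sql("select * from range(2)").show()  // resolved before and after this patch
spark.sql("select * from RaNgE(2)").show()  // previously failed analysis; now returns rows 0 and 1
```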
--- .../ResolveTableValuedFunctions.scala | 4 +--- .../inputs/table-valued-functions.sql | 6 ++++++ .../results/table-valued-functions.sql.out | 20 ++++++++++++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 6b3bb68538dd..8841309939c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -105,7 +103,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName) match { + builtinFunctions.get(u.functionName.toLowerCase()) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 2e6dcd538b7a..d0d2df7b243d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -18,3 +18,9 @@ select * from range(1, 1, 1, 1, 1); -- range call with null select * from range(1, null); + +-- range call with a mixed-case function name +select * from RaNgE(2); + +-- Explain +EXPLAIN select * from RaNgE(2); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index d769bcef0aca..acd4ecf14617 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 9 -- !query 0 @@ -85,3 +85,21 @@ struct<> -- !query 6 output java.lang.IllegalArgumentException Invalid arguments for resolved function: 1, null + + +-- !query 7 +select * from RaNgE(2) +-- !query 7 schema +struct +-- !query 7 output +0 +1 + + +-- !query 8 +EXPLAIN select * from RaNgE(2) +-- !query 8 schema +struct +-- !query 8 output +== Physical Plan == +*Range (0, 2, step=1, splits=None) From e7877fd4728ed41e440d7c4d8b6b02bd0d9e873e Mon Sep 17 00:00:00 2001 From: Ron Hu Date: Mon, 3 Apr 2017 17:27:12 -0700 Subject: [PATCH 198/512] [SPARK-19408][SQL] filter estimation on two columns of same table ## What changes were proposed in this pull request? In SQL queries, we also see predicate expressions involving two columns such as "column-1 (op) column-2" where column-1 and column-2 belong to same table. 
Note that, if column-1 and column-2 belong to different tables, then it is a join operator's work, NOT a filter operator's work. This PR estimates filter selectivity on two columns of same table. For example, multiple tpc-h queries have this predicate "WHERE l_commitdate < l_receiptdate" ## How was this patch tested? We added 6 new test cases to test various logical predicates involving two columns of same table. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ron Hu Author: U-CHINA\r00754707 Closes #17415 from ron8hu/filterTwoColumns. --- .../statsEstimation/FilterEstimation.scala | 233 +++++++++++++++++- .../FilterEstimationSuite.scala | 140 ++++++++++- 2 files changed, 363 insertions(+), 10 deletions(-) mode change 100644 => 100755 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala mode change 100644 => 100755 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala old mode 100644 new mode 100755 index b32374c5742e..03c76cd41d81 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -201,6 +201,21 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case IsNotNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = false, update) + case op @ Equality(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ LessThan(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ LessThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ GreaterThan(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ GreaterThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + case _ => // TODO: it's difficult to support string operators without advanced statistics. // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) @@ -257,7 +272,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo /** * Returns a percentage of rows meeting a binary comparison expression. * - * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param op a binary comparison operator such as =, <, <=, >, >= * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -448,7 +463,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting a binary comparison expression. * This method evaluate expression for Numeric/Date/Timestamp/Boolean columns. 
* - * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param op a binary comparison operator such as =, <, <=, >, >= * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -550,6 +565,220 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo Some(percent.toDouble) } + /** + * Returns a percentage of rows meeting a binary comparison expression containing two columns. + * In SQL queries, we also see predicate expressions involving two columns + * such as "column-1 (op) column-2" where column-1 and column-2 belong to same table. + * Note that, if column-1 and column-2 belong to different tables, then it is a join + * operator's work, NOT a filter operator's work. + * + * @param op a binary comparison operator, including =, <=>, <, <=, >, >= + * @param attrLeft the left Attribute (or a column) + * @param attrRight the right Attribute (or a column) + * @param update a boolean flag to specify if we need to update ColumnStat of the given columns + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateBinaryForTwoColumns( + op: BinaryComparison, + attrLeft: Attribute, + attrRight: Attribute, + update: Boolean): Option[Double] = { + + if (!colStatsMap.contains(attrLeft)) { + logDebug("[CBO] No statistics for " + attrLeft) + return None + } + if (!colStatsMap.contains(attrRight)) { + logDebug("[CBO] No statistics for " + attrRight) + return None + } + + attrLeft.dataType match { + case StringType | BinaryType => + // TODO: It is difficult to support other binary comparisons for String/Binary + // type without min/max and advanced statistics like histogram. + logDebug("[CBO] No range comparison statistics for String/Binary type " + attrLeft) + return None + case _ => + } + + val colStatLeft = colStatsMap(attrLeft) + val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) + .asInstanceOf[NumericRange] + val maxLeft = BigDecimal(statsRangeLeft.max) + val minLeft = BigDecimal(statsRangeLeft.min) + + val colStatRight = colStatsMap(attrRight) + val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) + .asInstanceOf[NumericRange] + val maxRight = BigDecimal(statsRangeRight.max) + val minRight = BigDecimal(statsRangeRight.min) + + // determine the overlapping degree between predicate range and column's range + val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val (noOverlap: Boolean, completeOverlap: Boolean) = op match { + // Left < Right or Left <= Right + // - no overlap: + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + // - complete overlap: (If null values exists, we set it to partial overlap.) + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + case _: LessThan => + (minLeft >= maxRight, (maxLeft < minRight) && allNotNull) + case _: LessThanOrEqual => + (minLeft > maxRight, (maxLeft <= minRight) && allNotNull) + + // Left > Right or Left >= Right + // - no overlap: + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + // - complete overlap: (If null values exists, we set it to partial overlap.) 
+ // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + case _: GreaterThan => + (maxLeft <= minRight, (minLeft > maxRight) && allNotNull) + case _: GreaterThanOrEqual => + (maxLeft < minRight, (minLeft >= maxRight) && allNotNull) + + // Left = Right or Left <=> Right + // - no overlap: + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + // - complete overlap: + // minLeft maxLeft + // minRight maxRight + // --------+------------------+-------> + case _: EqualTo => + ((maxLeft < minRight) || (maxRight < minLeft), + (minLeft == minRight) && (maxLeft == maxRight) && allNotNull + && (colStatLeft.distinctCount == colStatRight.distinctCount) + ) + case _: EqualNullSafe => + // For null-safe equality, we use a very restrictive condition to evaluate its overlap. + // If null values exists, we set it to partial overlap. + (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, + (minLeft == minRight) && (maxLeft == maxRight) && allNotNull + && (colStatLeft.distinctCount == colStatRight.distinctCount) + ) + } + + var percent = BigDecimal(1.0) + if (noOverlap) { + percent = 0.0 + } else if (completeOverlap) { + percent = 1.0 + } else { + // For partial overlap, we use an empirical value 1/3 as suggested by the book + // "Database Systems, the complete book". + percent = 1.0 / 3.0 + + if (update) { + // Need to adjust new min/max after the filter condition is applied + + val ndvLeft = BigDecimal(colStatLeft.distinctCount) + var newNdvLeft = (ndvLeft * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdvLeft < 1) newNdvLeft = 1 + val ndvRight = BigDecimal(colStatRight.distinctCount) + var newNdvRight = (ndvRight * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdvRight < 1) newNdvRight = 1 + + var newMaxLeft = colStatLeft.max + var newMinLeft = colStatLeft.min + var newMaxRight = colStatRight.max + var newMinRight = colStatRight.min + + op match { + case _: LessThan | _: LessThanOrEqual => + // the left side should be less than the right side. + // If not, we need to adjust it to narrow the range. + // Left < Right or Left <= Right + // minRight < minLeft + // --------+******************+-------> + // filtered ^ + // | + // newMinRight + // + // maxRight < maxLeft + // --------+******************+-------> + // ^ filtered + // | + // newMaxLeft + if (minLeft > minRight) newMinRight = colStatLeft.min + if (maxLeft > maxRight) newMaxLeft = colStatRight.max + + case _: GreaterThan | _: GreaterThanOrEqual => + // the left side should be greater than the right side. + // If not, we need to adjust it to narrow the range. + // Left > Right or Left >= Right + // minLeft < minRight + // --------+******************+-------> + // filtered ^ + // | + // newMinLeft + // + // maxLeft < maxRight + // --------+******************+-------> + // ^ filtered + // | + // newMaxRight + if (minLeft < minRight) newMinLeft = colStatRight.min + if (maxLeft < maxRight) newMaxRight = colStatLeft.max + + case _: EqualTo | _: EqualNullSafe => + // need to set new min to the larger min value, and + // set the new max to the smaller max value. 
+ // Left = Right or Left <=> Right + // minLeft < minRight + // --------+******************+-------> + // filtered ^ + // | + // newMinLeft + // + // minRight <= minLeft + // --------+******************+-------> + // filtered ^ + // | + // newMinRight + // + // maxLeft < maxRight + // --------+******************+-------> + // ^ filtered + // | + // newMaxRight + // + // maxRight <= maxLeft + // --------+******************+-------> + // ^ filtered + // | + // newMaxLeft + if (minLeft < minRight) { + newMinLeft = colStatRight.min + } else { + newMinRight = colStatLeft.min + } + if (maxLeft < maxRight) { + newMaxRight = colStatLeft.max + } else { + newMaxLeft = colStatRight.max + } + } + + val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + max = newMaxLeft) + colStatsMap(attrLeft) = newStatsLeft + val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + max = newMaxRight) + colStatsMap(attrRight) = newStatsRight + } + } + + Some(percent.toDouble) + } + } class ColumnStatsMap { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala old mode 100644 new mode 100755 index 1966c96c0529..cffb0d873928 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -33,49 +33,74 @@ import org.apache.spark.sql.types._ class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 6 columns. - // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) - // only 2 values + // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1) - // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. + // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") val attrDate = AttributeReference("cdate", DateType)() val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) - // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. + // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) - // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) - // Sixth column cstring has 10 String values: + // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2) + // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test "cint < cint2 + val attrInt2 = AttributeReference("cint2", IntegerType)() + val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), + nullCount = 0, avgLen = 4, maxLen = 4) + + // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 + // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test "cint = cint3 without overlap at all. + val attrInt3 = AttributeReference("cint3", IntegerType)() + val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), + nullCount = 0, avgLen = 4, maxLen = 4) + + // column cint4 has values in the range from 1 to 10 + // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test complete overlap + val attrInt4 = AttributeReference("cint4", IntegerType)() + val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, attrBool -> colStatBool, attrDate -> colStatDate, attrDecimal -> colStatDecimal, attrDouble -> colStatDouble, - attrString -> colStatString)) + attrString -> colStatString, + attrInt2 -> colStatInt2, + attrInt3 -> colStatInt3, + attrInt4 -> colStatInt4 + )) test("true") { validateEstimatedStats( @@ -450,6 +475,89 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } } + test("cint = cint2") { + // partial overlap case + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint > cint2") { + // partial overlap case + validateEstimatedStats( + Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint < cint2") { + // partial overlap case + 
validateEstimatedStats( + Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(16), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint = cint4") { + // complete overlap case + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint < cint4") { + // partial overlap case + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt4 -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint = cint3") { + // no records qualify due to no overlap + val emptyColStats = Seq[(Attribute, ColumnStat)]() + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Nil, // set to empty + expectedRowCount = 0) + } + + test("cint < cint3") { + // all table records qualify. + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint > cint3") { + // no records qualify due to no overlap + validateEstimatedStats( + Filter(GreaterThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Nil, // set to empty + expectedRowCount = 0) + } + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, @@ -491,7 +599,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap), rowCount = Some(expectedRowCount), attributeStats = expectedAttributeMap) - assert(filter.stats(conf) == expectedStats) + + val filterStats = filter.stats(conf) + assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) + assert(filterStats.rowCount == expectedStats.rowCount) + val rowCountValue = filterStats.rowCount.getOrElse(0) + // check the output column stats if the row count is > 0. + // When row count is 0, the output is set to empty. + if (rowCountValue != 0) { + // Need to check attributeStats one by one because we may have multiple output columns. + // Due to update operation, the output columns may be in different order. 
+ assert(expectedColStats.size == filterStats.attributeStats.size) + expectedColStats.foreach { kv => + val filterColumnStat = filterStats.attributeStats.get(kv._1).get + assert(filterColumnStat == kv._2) + } + } } } + } From 3bfb639cb7352aec572ef6686d3471bd78748ffa Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 4 Apr 2017 09:53:05 +0900 Subject: [PATCH 199/512] [SPARK-10364][SQL] Support Parquet logical type TIMESTAMP_MILLIS ## What changes were proposed in this pull request? **Description** from JIRA The TimestampType in Spark SQL is of microsecond precision. Ideally, we should convert Spark SQL timestamp values into Parquet TIMESTAMP_MICROS. But unfortunately parquet-mr hasn't supported it yet. For the read path, we should be able to read TIMESTAMP_MILLIS Parquet values and pad a 0 microsecond part to read values. For the write path, currently we are writing timestamps as INT96, similar to Impala and Hive. One alternative is that, we can have a separate SQL option to let users be able to write Spark SQL timestamp values as TIMESTAMP_MILLIS. Of course, in this way the microsecond part will be truncated. ## How was this patch tested? Added new tests in ParquetQuerySuite and ParquetIOSuite Author: Dilip Biswal Closes #15332 from dilipbiswal/parquet-time-millis. --- .../sql/catalyst/util/DateTimeUtils.scala | 19 +++++ .../apache/spark/sql/internal/SQLConf.scala | 9 +++ .../SpecificParquetRecordReaderBase.java | 1 + .../parquet/VectorizedColumnReader.java | 27 ++++++- .../parquet/ParquetFileFormat.scala | 14 +++- .../parquet/ParquetRowConverter.scala | 9 ++- .../parquet/ParquetSchemaConverter.scala | 25 ++++-- .../parquet/ParquetWriteSupport.scala | 15 ++++ .../test-data/timemillis-in-i64.parquet | Bin 0 -> 517 bytes .../datasources/parquet/ParquetIOSuite.scala | 16 +++- .../parquet/ParquetQuerySuite.scala | 73 ++++++++++++++++++ .../parquet/ParquetSchemaSuite.scala | 33 ++++++-- 12 files changed, 221 insertions(+), 20 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/timemillis-in-i64.parquet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 9b94c1e2b40b..f614965520f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -44,6 +44,7 @@ object DateTimeUtils { final val JULIAN_DAY_OF_EPOCH = 2440588 final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L + final val MILLIS_PER_SECOND = 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY @@ -237,6 +238,24 @@ object DateTimeUtils { (day.toInt, micros * 1000L) } + /* + * Converts the timestamp to milliseconds since epoch. In spark timestamp values have microseconds + * precision, so this conversion is lossy. + */ + def toMillis(us: SQLTimestamp): Long = { + // When the timestamp is negative i.e before 1970, we need to adjust the millseconds portion. + // Example - 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in micro precision. + // In millis precision the above needs to be represented as (-157700927877). + Math.floor(us.toDouble / MILLIS_PER_SECOND).toLong + } + + /* + * Converts millseconds since epoch to SQLTimestamp. 
+ */ + def fromMillis(millis: Long): SQLTimestamp = { + millis * 1000L + } + /** * Parses a given UTF8 date string to the corresponding a corresponding [[Long]] value. * The return type is [[Option]] in order to distinguish between 0L and null. The following diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5566b06aa355..06dc0b41204f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -227,6 +227,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_INT64_AS_TIMESTAMP_MILLIS = buildConf("spark.sql.parquet.int64AsTimestampMillis") + .doc("When true, timestamp values will be stored as INT64 with TIMESTAMP_MILLIS as the " + + "extended type. In this mode, the microsecond portion of the timestamp value will be" + + "truncated.") + .booleanConf + .createWithDefault(false) + val PARQUET_CACHE_METADATA = buildConf("spark.sql.parquet.cacheMetadata") .doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.") .booleanConf @@ -935,6 +942,8 @@ class SQLConf extends Serializable with Logging { def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + def isParquetINT64AsTimestampMillis: Boolean = getConf(PARQUET_INT64_AS_TIMESTAMP_MILLIS) + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index bf8717483575..eb97118872ea 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -197,6 +197,7 @@ protected void initialize(String path, List columns) throws IOException config.set("spark.sql.parquet.binaryAsString", "false"); config.set("spark.sql.parquet.int96AsTimestamp", "false"); config.set("spark.sql.parquet.writeLegacyFormat", "false"); + config.set("spark.sql.parquet.int64AsTimestampMillis", "false"); this.file = new Path(path); long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index cb51cb499eed..9d641b528723 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -28,6 +28,7 @@ import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.PrimitiveType; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -155,9 +156,13 @@ void readBatch(int total, ColumnVector column) throws IOException { // Read and decode dictionary ids. 
defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + + // Timestamp values encoded as INT64 can't be lazily decoded as we need to post process + // the values to add microseconds precision. if (column.hasDictionary() || (rowId == 0 && (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 || + (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && + column.dataType() != DataTypes.TimestampType) || descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { @@ -250,7 +255,15 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i))); } } - } else { + } else if (column.dataType() == DataTypes.TimestampType) { + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + column.putLong(i, + DateTimeUtils.fromMillis(dictionary.decodeToLong(dictionaryIds.getDictId(i)))); + } + } + } + else { throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } break; @@ -362,7 +375,15 @@ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOExc if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType())) { defColumn.readLongs( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else if (column.dataType() == DataTypes.TimestampType) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putLong(rowId + i, DateTimeUtils.fromMillis(dataColumn.readLong())); + } else { + column.putNull(rowId + i); + } + } } else { throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 062aa5c8ea62..2f3a2c62b912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -125,6 +125,10 @@ class ParquetFileFormat SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, sparkSession.sessionState.conf.writeLegacyParquetFormat.toString) + conf.set( + SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) @@ -300,6 +304,9 @@ class ParquetFileFormat hadoopConf.setBoolean( SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) + hadoopConf.setBoolean( + SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) // Try to push down filters when filter push-down is enabled. 
val pushed = @@ -410,7 +417,8 @@ object ParquetFileFormat extends Logging { val converter = new ParquetSchemaConverter( sparkSession.sessionState.conf.isParquetBinaryAsString, sparkSession.sessionState.conf.isParquetBinaryAsString, - sparkSession.sessionState.conf.writeLegacyParquetFormat) + sparkSession.sessionState.conf.writeLegacyParquetFormat, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) converter.convert(schema) } @@ -510,6 +518,7 @@ object ParquetFileFormat extends Logging { sparkSession: SparkSession): Option[StructType] = { val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp + val writeTimestampInMillis = sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) @@ -554,7 +563,8 @@ object ParquetFileFormat extends Logging { new ParquetSchemaConverter( assumeBinaryIsString = assumeBinaryIsString, assumeInt96IsTimestamp = assumeInt96IsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampInMillis = writeTimestampInMillis) if (footers.isEmpty) { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 33dcf2f3fd16..32e6c60cd976 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.{GroupType, MessageType, Type} +import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type} import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64} @@ -252,6 +252,13 @@ private[parquet] class ParquetRowConverter( case StringType => new ParquetStringConverter(updater) + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(DateTimeUtils.fromMillis(value)) + } + } + case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. new ParquetPrimitiveConverter(updater) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 66d4027edf9f..0b805e436288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -51,22 +51,29 @@ import org.apache.spark.sql.types._ * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]]. 
* When set to false, use standard format defined in parquet-format spec. This argument only * affects Parquet write path. + * @param writeTimestampInMillis Whether to write timestamp values as INT64 annotated by logical + * type TIMESTAMP_MILLIS. + * */ private[parquet] class ParquetSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, - writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get) { + writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, + writeTimestampInMillis: Boolean = SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, - writeLegacyParquetFormat = conf.writeLegacyParquetFormat) + writeLegacyParquetFormat = conf.writeLegacyParquetFormat, + writeTimestampInMillis = conf.isParquetINT64AsTimestampMillis) def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean) + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean, + writeTimestampInMillis = conf.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean) + /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. @@ -158,7 +165,7 @@ private[parquet] class ParquetSchemaConverter( case INT_64 | null => LongType case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() - case TIMESTAMP_MILLIS => typeNotImplemented() + case TIMESTAMP_MILLIS => TimestampType case _ => illegalType() } @@ -370,10 +377,16 @@ private[parquet] class ParquetSchemaConverter( // we may resort to microsecond precision in the future. // // For Parquet, we plan to write all `TimestampType` value as `TIMESTAMP_MICROS`, but it's - // currently not implemented yet because parquet-mr 1.7.0 (the version we're currently using) - // hasn't implemented `TIMESTAMP_MICROS` yet. + // currently not implemented yet because parquet-mr 1.8.1 (the version we're currently using) + // hasn't implemented `TIMESTAMP_MICROS` yet, however it supports TIMESTAMP_MILLIS. We will + // encode timestamp values as TIMESTAMP_MILLIS annotating INT64 if + // 'spark.sql.parquet.int64AsTimestampMillis' is set. // // TODO Converts `TIMESTAMP_MICROS` once parquet-mr implements that. 
+ + case TimestampType if writeTimestampInMillis => + Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + case TimestampType => Types.primitive(INT96, repetition).named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index a31d2b9c37e9..38b0e33937f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -66,6 +66,9 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // Whether to write data in legacy Parquet format compatible with Spark 1.4 and prior versions private var writeLegacyParquetFormat: Boolean = _ + // Whether to write timestamp value with milliseconds precision. + private var writeTimestampInMillis: Boolean = _ + // Reusable byte array used to write timestamps as Parquet INT96 values private val timestampBuffer = new Array[Byte](12) @@ -80,6 +83,13 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit assert(configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key) != null) configuration.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean } + + this.writeTimestampInMillis = { + assert(configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key) != null) + configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean + } + + this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) val messageType = new ParquetSchemaConverter(configuration).convert(schema) @@ -153,6 +163,11 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit recordConsumer.addBinary( Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes)) + case TimestampType if writeTimestampInMillis => + (row: SpecializedGetters, ordinal: Int) => + val millis = DateTimeUtils.toMillis(row.getLong(ordinal)) + recordConsumer.addLong(millis) + case TimestampType => (row: SpecializedGetters, ordinal: Int) => { // TODO Writes `TimestampType` values as `TIMESTAMP_MICROS` once parquet-mr implements it diff --git a/sql/core/src/test/resources/test-data/timemillis-in-i64.parquet b/sql/core/src/test/resources/test-data/timemillis-in-i64.parquet new file mode 100644 index 0000000000000000000000000000000000000000..d3c39e2c26eece8d20c154283c1a3fac40859efd GIT binary patch literal 517 zcmaKq&r8EF6vxvVq=*L*6I$q@1U4LW!R~j5m)&*{9h)~%N!wJ5ZP&G_B4Z$)_Gg>! z2h)o=Bros#KJR@4nT)0m0_Y4~*hrOuhBQ;xPQZ2@B3vajb1xuRAvY3%f6_n-nk_eZ z{MSi^;0URPJw7cmmcKn0{wq%yQYBvlIuudDYv%wT8@6HAH4{Oj1~g+UAQh|F!(m;! 
zKKMIC8>cwLDv)TpLE$eH;x7fSm3sOQyjC!jwBDHKFO+3Wnxh+^v{=Mc8eWuK(0u+u z6E0Z51k<0EM0{qP3`rsK(ig-gVZ`I0Aj5|xNm)`!)w86qE39sXU`ZxZX&J}Ni)B&B z;)2^`-{FdO^DrxK6L_yM4}mci^_Jf literal 0 HcmV?d00001 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index dbdcd230a4de..57a0af1dda97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -107,11 +107,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { | required binary g(ENUM); | required binary h(DECIMAL(32,0)); | required fixed_len_byte_array(32) i(DECIMAL(32,0)); + | required int64 j(TIMESTAMP_MILLIS); |} """.stripMargin) val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), - DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0)) + DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0), + TimestampType) withTempPath { location => val path = new Path(location.getCanonicalPath) @@ -607,6 +609,18 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("read dictionary and plain encoded timestamp_millis written as INT64") { + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // timestamp column in this file is encoded using combination of plain + // and dictionary encodings. + readResourceParquetFile("test-data/timemillis-in-i64.parquet"), + (1 to 3).map(i => Row(new java.sql.Timestamp(10)))) + } + } + } + test("SPARK-12589 copy() on rows returned from reader works for strings") { withTempPath { dir => val data = (1, "abc") ::(2, "helloabcde") :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 200e356c72fd..c36609586c80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import java.sql.Timestamp import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.parquet.hadoop.ParquetOutputFormat @@ -162,6 +163,78 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-10634 timestamp written and read as INT64 - TIMESTAMP_MILLIS") { + val data = (1 to 10).map(i => Row(i, new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", IntegerType, false), + StructField("time", TimestampType, false)).toArray) + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "true") { + withTempPath { file => + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + val df2 = spark.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + } + } + } + + test("SPARK-10634 timestamp written and read as INT64 - 
truncation") { + withTable("ts") { + sql("create table ts (c1 int, c2 timestamp) using parquet") + sql("insert into ts values (1, '2016-01-01 10:11:12.123456')") + sql("insert into ts values (2, null)") + sql("insert into ts values (3, '1965-01-01 10:11:12.123456')") + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123456")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.123456")))) + } + + // The microsecond portion is truncated when written as TIMESTAMP_MILLIS. + withTable("ts") { + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "true") { + sql("create table ts (c1 int, c2 timestamp) using parquet") + sql("insert into ts values (1, '2016-01-01 10:11:12.123456')") + sql("insert into ts values (2, null)") + sql("insert into ts values (3, '1965-01-01 10:11:12.125456')") + sql("insert into ts values (4, '1965-01-01 10:11:12.125')") + sql("insert into ts values (5, '1965-01-01 10:11:12.1')") + sql("insert into ts values (6, '1965-01-01 10:11:12.123456789')") + sql("insert into ts values (7, '0001-01-01 00:00:00.000000')") + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(4, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(5, Timestamp.valueOf("1965-01-01 10:11:12.1")), + Row(6, Timestamp.valueOf("1965-01-01 10:11:12.123")), + Row(7, Timestamp.valueOf("0001-01-01 00:00:00.000")))) + + // Read timestamps that were encoded as TIMESTAMP_MILLIS annotated as INT64 + // with PARQUET_INT64_AS_TIMESTAMP_MILLIS set to false. + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "false") { + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(4, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(5, Timestamp.valueOf("1965-01-01 10:11:12.1")), + Row(6, Timestamp.valueOf("1965-01-01 10:11:12.123")), + Row(7, Timestamp.valueOf("0001-01-01 00:00:00.000")))) + } + } + } + } + test("Enabling/disabling merging partfiles when merging parquet schema") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 6aa940afbb2c..ce992674d719 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -53,11 +53,13 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { val converter = new ParquetSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampInMillis = int64AsTimestampMillis) test(s"sql <= parquet: $testName") { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) @@ -77,11 +79,13 @@ abstract class ParquetSchemaTest extends ParquetTest 
with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { val converter = new ParquetSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampInMillis = int64AsTimestampMillis) test(s"sql => parquet: $testName") { val actual = converter.convert(sqlSchema) @@ -97,7 +101,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { testCatalystToParquet( testName, @@ -105,7 +110,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - writeLegacyParquetFormat) + writeLegacyParquetFormat, + int64AsTimestampMillis) testParquetToCatalyst( testName, @@ -113,7 +119,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - writeLegacyParquetFormat) + writeLegacyParquetFormat, + int64AsTimestampMillis) } } @@ -965,6 +972,18 @@ class ParquetSchemaSuite extends ParquetSchemaTest { int96AsTimestamp = true, writeLegacyParquetFormat = true) + testSchema( + "Timestamp written and read as INT64 with TIMESTAMP_MILLIS", + StructType(Seq(StructField("f1", TimestampType))), + """message root { + | optional INT64 f1 (TIMESTAMP_MILLIS); + |} + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = false, + writeLegacyParquetFormat = true, + int64AsTimestampMillis = true) + private def testSchemaClipping( testName: String, parquetSchema: String, From 51d3c854c54369aec1bfd55cefcd080dcd178d5f Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 3 Apr 2017 23:30:12 -0700 Subject: [PATCH 200/512] [SPARK-20067][SQL] Unify and Clean Up Desc Commands Using Catalog Interface ### What changes were proposed in this pull request? This PR is to unify and clean up the outputs of `DESC EXTENDED/FORMATTED` and `SHOW TABLE EXTENDED` by moving the logics into the Catalog interface. The output formats are improved. We also add the missing attributes. It impacts the DDL commands like `SHOW TABLE EXTENDED`, `DESC EXTENDED` and `DESC FORMATTED`. In addition, by following what we did in Dataset API `printSchema`, we can use `treeString` to show the schema in the more readable way. Below is the current way: ``` Schema: STRUCT<`a`: STRING (nullable = true), `b`: INT (nullable = true), `c`: STRING (nullable = true), `d`: STRING (nullable = true)> ``` After the change, it should look like ``` Schema: root |-- a: string (nullable = true) |-- b: integer (nullable = true) |-- c: string (nullable = true) |-- d: string (nullable = true) ``` ### How was this patch tested? `describe.sql` and `show-tables.sql` Author: Xiao Li Closes #17394 from gatorsmile/descFollowUp. 
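To make the new rendering concrete, below is a minimal, standalone sketch of the pattern this change moves into the catalog interface (the `DescOutputSketch` object and its parameters are illustrative stand-ins, not Spark code; only the `toLinkedHashMap`/`simpleString` idea mirrors the methods added to the catalog classes): metadata is collected into a `mutable.LinkedHashMap` so insertion order is preserved, and each entry is rendered as a `key: value` line.

```scala
import scala.collection.mutable

// Standalone sketch only: a simplified stand-in for the toLinkedHashMap/simpleString
// pair added to CatalogTable and the other catalog classes in this patch.
object DescOutputSketch {

  // Collect metadata into a LinkedHashMap so the printed attribute order is
  // exactly the insertion order, i.e. deterministic across runs.
  def toLinkedHashMap(
      database: String,
      table: String,
      tableType: String,
      provider: Option[String],
      comment: Option[String]): mutable.LinkedHashMap[String, String] = {
    val map = new mutable.LinkedHashMap[String, String]()
    map.put("Database", database)
    map.put("Table", table)
    map.put("Type", tableType)
    provider.foreach(map.put("Provider", _))
    comment.foreach(map.put("Comment", _))
    map
  }

  // Render one "key: value" line per entry; keys with empty values print alone.
  def simpleString(map: mutable.LinkedHashMap[String, String]): String =
    map.map { case (key, value) =>
      if (value.isEmpty) key else s"$key: $value"
    }.mkString("", "\n", "")

  def main(args: Array[String]): Unit = {
    val meta = toLinkedHashMap("default", "t", "MANAGED", Some("parquet"), Some("table_comment"))
    println(simpleString(meta))
    // Database: default
    // Table: t
    // Type: MANAGED
    // Provider: parquet
    // Comment: table_comment
  }
}
```

Keeping the attributes in a LinkedHashMap rather than an unordered Map is what makes the `DESC`/`SHOW TABLE EXTENDED` output stable enough to be checked against the regenerated golden files (`describe.sql.out`, `show-tables.sql.out`).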
--- .../sql/catalyst/catalog/interface.scala | 136 ++++-- .../spark/sql/execution/SparkSqlParser.scala | 3 +- .../spark/sql/execution/command/tables.scala | 124 ++--- .../resources/sql-tests/inputs/describe.sql | 53 ++- .../sql-tests/results/change-column.sql.out | 9 + .../sql-tests/results/describe.sql.out | 422 ++++++++++++++---- .../sql-tests/results/show-tables.sql.out | 67 +-- .../apache/spark/sql/SQLQueryTestSuite.scala | 19 +- .../sql/execution/SparkSqlParserSuite.scala | 6 +- .../sql/execution/command/DDLSuite.scala | 12 - .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 4 - .../spark/sql/sources/DDLTestSuite.scala | 123 ----- .../spark/sql/sources/DataSourceTest.scala | 56 +++ .../sql/hive/MetastoreDataSourcesSuite.scala | 8 +- .../hive/execution/HiveComparisonTest.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 93 +--- .../HiveOperatorQueryableSuite.scala | 53 --- .../sql/hive/execution/HiveQuerySuite.scala | 56 --- .../sql/hive/execution/SQLQuerySuite.scala | 131 +----- 19 files changed, 642 insertions(+), 737 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 70ed44e025f5..3f25f9e7258f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI import java.util.Date +import scala.collection.mutable + import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException @@ -57,20 +59,25 @@ case class CatalogStorageFormat( properties: Map[String, String]) { override def toString: String = { - val serdePropsToString = CatalogUtils.maskCredentials(properties) match { - case props if props.isEmpty => "" - case props => "Properties: " + props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") - } - val output = - Seq(locationUri.map("Location: " + _).getOrElse(""), - inputFormat.map("InputFormat: " + _).getOrElse(""), - outputFormat.map("OutputFormat: " + _).getOrElse(""), - if (compressed) "Compressed" else "", - serde.map("Serde: " + _).getOrElse(""), - serdePropsToString) - output.filter(_.nonEmpty).mkString("Storage(", ", ", ")") + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("Storage(", ", ", ")") } + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new mutable.LinkedHashMap[String, String]() + locationUri.foreach(l => map.put("Location", l.toString)) + serde.foreach(map.put("Serde Library", _)) + inputFormat.foreach(map.put("InputFormat", _)) + outputFormat.foreach(map.put("OutputFormat", _)) + if (compressed) map.put("Compressed", "") + CatalogUtils.maskCredentials(properties) match { + case props if props.isEmpty => // No-op + case props => + map.put("Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")) + } + map + } } object CatalogStorageFormat { @@ -91,15 +98,28 @@ case class CatalogTablePartition( storage: CatalogStorageFormat, parameters: Map[String, String] = Map.empty) { - override def toString: String = { + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new 
mutable.LinkedHashMap[String, String]() val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") - val output = - Seq( - s"Partition Values: [$specString]", - s"$storage", - s"Partition Parameters:{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + map.put("Partition Values", s"[$specString]") + map ++= storage.toLinkedHashMap + if (parameters.nonEmpty) { + map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + } + map + } - output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") + override def toString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("CatalogPartition(\n\t", "\n\t", ")") + } + + /** Readable string representation for the CatalogTablePartition. */ + def simpleString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("", "\n", "") } /** Return the partition location, assuming it is specified. */ @@ -154,6 +174,14 @@ case class BucketSpec( } s"$numBuckets buckets, $bucketString$sortString" } + + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + mutable.LinkedHashMap[String, String]( + "Num Buckets" -> numBuckets.toString, + "Bucket Columns" -> bucketColumnNames.map(quoteIdentifier).mkString("[", ", ", "]"), + "Sort Columns" -> sortColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") + ) + } } /** @@ -261,40 +289,50 @@ case class CatalogTable( locationUri, inputFormat, outputFormat, serde, compressed, properties)) } - override def toString: String = { + + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new mutable.LinkedHashMap[String, String]() val tableProperties = properties.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") val partitionColumns = partitionColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - val bucketStrings = bucketSpec match { - case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => - val bucketColumnsString = bucketColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - val sortColumnsString = sortColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - Seq( - s"Num Buckets: $numBuckets", - if (bucketColumnNames.nonEmpty) s"Bucket Columns: $bucketColumnsString" else "", - if (sortColumnNames.nonEmpty) s"Sort Columns: $sortColumnsString" else "" - ) - - case _ => Nil + + identifier.database.foreach(map.put("Database", _)) + map.put("Table", identifier.table) + if (owner.nonEmpty) map.put("Owner", owner) + map.put("Created", new Date(createTime).toString) + map.put("Last Access", new Date(lastAccessTime).toString) + map.put("Type", tableType.name) + provider.foreach(map.put("Provider", _)) + bucketSpec.foreach(map ++= _.toLinkedHashMap) + comment.foreach(map.put("Comment", _)) + if (tableType == CatalogTableType.VIEW) { + viewText.foreach(map.put("View Text", _)) + viewDefaultDatabase.foreach(map.put("View Default Database", _)) + if (viewQueryColumnNames.nonEmpty) { + map.put("View Query Output Columns", viewQueryColumnNames.mkString("[", ", ", "]")) + } } - val output = - Seq(s"Table: ${identifier.quotedString}", - if (owner.nonEmpty) s"Owner: $owner" else "", - s"Created: ${new Date(createTime).toString}", - s"Last Access: ${new Date(lastAccessTime).toString}", - s"Type: ${tableType.name}", - if (schema.nonEmpty) s"Schema: ${schema.mkString("[", ", ", "]")}" else "", - if (provider.isDefined) s"Provider: ${provider.get}" else "", - if 
(partitionColumnNames.nonEmpty) s"Partition Columns: $partitionColumns" else "" - ) ++ bucketStrings ++ Seq( - viewText.map("View: " + _).getOrElse(""), - comment.map("Comment: " + _).getOrElse(""), - if (properties.nonEmpty) s"Properties: $tableProperties" else "", - if (stats.isDefined) s"Statistics: ${stats.get.simpleString}" else "", - s"$storage", - if (tracksPartitionsInCatalog) "Partition Provider: Catalog" else "") - - output.filter(_.nonEmpty).mkString("CatalogTable(\n\t", "\n\t", ")") + if (properties.nonEmpty) map.put("Properties", tableProperties) + stats.foreach(s => map.put("Statistics", s.simpleString)) + map ++= storage.toLinkedHashMap + if (tracksPartitionsInCatalog) map.put("Partition Provider", "Catalog") + if (partitionColumnNames.nonEmpty) map.put("Partition Columns", partitionColumns) + if (schema.nonEmpty) map.put("Schema", schema.treeString) + + map + } + + override def toString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("CatalogTable(\n", "\n", ")") + } + + /** Readable string representation for the CatalogTable. */ + def simpleString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("", "\n", "") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d4f23f9dd518..80afb59b3e88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -322,8 +322,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { DescribeTableCommand( visitTableIdentifier(ctx.tableIdentifier), partitionSpec, - ctx.EXTENDED != null, - ctx.FORMATTED != null) + ctx.EXTENDED != null || ctx.FORMATTED != null) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index c7aeef06a0bf..ebf03e1bf886 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -500,8 +500,7 @@ case class TruncateTableCommand( case class DescribeTableCommand( table: TableIdentifier, partitionSpec: TablePartitionSpec, - isExtended: Boolean, - isFormatted: Boolean) + isExtended: Boolean) extends RunnableCommand { override val output: Seq[Attribute] = Seq( @@ -536,14 +535,12 @@ case class DescribeTableCommand( describePartitionInfo(metadata, result) - if (partitionSpec.isEmpty) { - if (isExtended) { - describeExtendedTableInfo(metadata, result) - } else if (isFormatted) { - describeFormattedTableInfo(metadata, result) - } - } else { + if (partitionSpec.nonEmpty) { + // Outputs the partition-specific info for the DDL command: + // "DESCRIBE [EXTENDED|FORMATTED] table_name PARTITION (partitionVal*)" describeDetailedPartitionInfo(sparkSession, catalog, metadata, result) + } else if (isExtended) { + describeFormattedTableInfo(metadata, result) } } @@ -553,76 +550,20 @@ case class DescribeTableCommand( private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { if (table.partitionColumnNames.nonEmpty) { append(buffer, "# Partition Information", "", "") - append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) describeSchema(table.partitionSchema, buffer) } } - private def 
describeExtendedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "# Detailed Table Information", table.toString, "") - } - private def describeFormattedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + // The following information has been already shown in the previous outputs + val excludedTableInfo = Seq( + "Partition Columns", + "Schema" + ) append(buffer, "", "", "") append(buffer, "# Detailed Table Information", "", "") - append(buffer, "Database:", table.database, "") - append(buffer, "Owner:", table.owner, "") - append(buffer, "Created:", new Date(table.createTime).toString, "") - append(buffer, "Last Access:", new Date(table.lastAccessTime).toString, "") - append(buffer, "Location:", table.storage.locationUri.map(CatalogUtils.URIToString(_)) - .getOrElse(""), "") - append(buffer, "Table Type:", table.tableType.name, "") - append(buffer, "Comment:", table.comment.getOrElse(""), "") - table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, "")) - - append(buffer, "Table Parameters:", "", "") - table.properties.foreach { case (key, value) => - append(buffer, s" $key", value, "") - } - - describeStorageInfo(table, buffer) - - if (table.tableType == CatalogTableType.VIEW) describeViewInfo(table, buffer) - - if (DDLUtils.isDatasourceTable(table) && table.tracksPartitionsInCatalog) { - append(buffer, "Partition Provider:", "Catalog", "") - } - } - - private def describeStorageInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "# Storage Information", "", "") - metadata.storage.serde.foreach(serdeLib => append(buffer, "SerDe Library:", serdeLib, "")) - metadata.storage.inputFormat.foreach(format => append(buffer, "InputFormat:", format, "")) - metadata.storage.outputFormat.foreach(format => append(buffer, "OutputFormat:", format, "")) - append(buffer, "Compressed:", if (metadata.storage.compressed) "Yes" else "No", "") - describeBucketingInfo(metadata, buffer) - - append(buffer, "Storage Desc Parameters:", "", "") - val maskedProperties = CatalogUtils.maskCredentials(metadata.storage.properties) - maskedProperties.foreach { case (key, value) => - append(buffer, s" $key", value, "") - } - } - - private def describeViewInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "# View Information", "", "") - append(buffer, "View Text:", metadata.viewText.getOrElse(""), "") - append(buffer, "View Default Database:", metadata.viewDefaultDatabase.getOrElse(""), "") - append(buffer, "View Query Output Columns:", - metadata.viewQueryColumnNames.mkString("[", ", ", "]"), "") - } - - private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - metadata.bucketSpec match { - case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => - append(buffer, "Num Buckets:", numBuckets.toString, "") - append(buffer, "Bucket Columns:", bucketColumnNames.mkString("[", ", ", "]"), "") - append(buffer, "Sort Columns:", sortColumnNames.mkString("[", ", ", "]"), "") - - case _ => + table.toLinkedHashMap.filterKeys(!excludedTableInfo.contains(_)).foreach { + s => append(buffer, s._1, s._2, "") } } @@ -637,21 +578,7 @@ case class DescribeTableCommand( } DDLUtils.verifyPartitionProviderIsHive(spark, metadata, "DESC PARTITION") val partition = catalog.getPartition(table, partitionSpec) - if (isExtended) { - describeExtendedDetailedPartitionInfo(table, metadata, 
partition, result) - } else if (isFormatted) { - describeFormattedDetailedPartitionInfo(table, metadata, partition, result) - describeStorageInfo(metadata, result) - } - } - - private def describeExtendedDetailedPartitionInfo( - tableIdentifier: TableIdentifier, - table: CatalogTable, - partition: CatalogTablePartition, - buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "Detailed Partition Information " + partition.toString, "", "") + if (isExtended) describeFormattedDetailedPartitionInfo(table, metadata, partition, result) } private def describeFormattedDetailedPartitionInfo( @@ -661,18 +588,21 @@ case class DescribeTableCommand( buffer: ArrayBuffer[Row]): Unit = { append(buffer, "", "", "") append(buffer, "# Detailed Partition Information", "", "") - append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "") - append(buffer, "Database:", table.database, "") - append(buffer, "Table:", tableIdentifier.table, "") - append(buffer, "Location:", partition.storage.locationUri.map(CatalogUtils.URIToString(_)) - .getOrElse(""), "") - append(buffer, "Partition Parameters:", "", "") - partition.parameters.foreach { case (key, value) => - append(buffer, s" $key", value, "") + append(buffer, "Database", table.database, "") + append(buffer, "Table", tableIdentifier.table, "") + partition.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) + append(buffer, "", "", "") + append(buffer, "# Storage Information", "", "") + table.bucketSpec match { + case Some(spec) => + spec.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) + case _ => } + table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) } private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) schema.foreach { column => append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) } @@ -728,7 +658,7 @@ case class ShowTablesCommand( val tableName = tableIdent.table val isTemp = catalog.isTemporaryTable(tableIdent) if (isExtended) { - val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).toString + val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).simpleString Row(database, tableName, isTemp, s"$information\n") } else { Row(database, tableName, isTemp) @@ -745,7 +675,7 @@ case class ShowTablesCommand( val database = table.database.getOrElse("") val tableName = table.table val isTemp = catalog.isTemporaryTable(table) - val information = partition.toString + val information = partition.simpleString Seq(Row(database, tableName, isTemp, s"$information\n")) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 56f3281440d2..6de4cf0d5afa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -1,10 +1,23 @@ -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) COMMENT 'table_comment'; +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS + COMMENT 'table_comment'; + +CREATE TEMPORARY VIEW temp_v AS SELECT * FROM t; + +CREATE TEMPORARY VIEW temp_Data_Source_View + USING org.apache.spark.sql.sources.DDLScanSource + OPTIONS ( + From '1', + To '10', + Table 'test1'); + +CREATE VIEW v AS SELECT * FROM t; ALTER 
TABLE t ADD PARTITION (c='Us', d=1); DESCRIBE t; -DESC t; +DESC default.t; DESC TABLE t; @@ -27,5 +40,39 @@ DESC t PARTITION (c='Us'); -- ParseException: PARTITION specification is incomplete DESC t PARTITION (c='Us', d); --- DROP TEST TABLE +-- DESC Temp View + +DESC temp_v; + +DESC TABLE temp_v; + +DESC FORMATTED temp_v; + +DESC EXTENDED temp_v; + +DESC temp_Data_Source_View; + +-- AnalysisException DESC PARTITION is not allowed on a temporary view +DESC temp_v PARTITION (c='Us', d=1); + +-- DESC Persistent View + +DESC v; + +DESC TABLE v; + +DESC FORMATTED v; + +DESC EXTENDED v; + +-- AnalysisException DESC PARTITION is not allowed on a view +DESC v PARTITION (c='Us', d=1); + +-- DROP TEST TABLES/VIEWS DROP TABLE t; + +DROP VIEW temp_v; + +DROP VIEW temp_Data_Source_View; + +DROP VIEW v; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ba8bc936f0c7..678a3f0f0a3c 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -15,6 +15,7 @@ DESC test_change -- !query 1 schema struct -- !query 1 output +# col_name data_type comment a int b string c int @@ -34,6 +35,7 @@ DESC test_change -- !query 3 schema struct -- !query 3 output +# col_name data_type comment a int b string c int @@ -53,6 +55,7 @@ DESC test_change -- !query 5 schema struct -- !query 5 output +# col_name data_type comment a int b string c int @@ -91,6 +94,7 @@ DESC test_change -- !query 8 schema struct -- !query 8 output +# col_name data_type comment a int b string c int @@ -125,6 +129,7 @@ DESC test_change -- !query 12 schema struct -- !query 12 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -143,6 +148,7 @@ DESC test_change -- !query 14 schema struct -- !query 14 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -162,6 +168,7 @@ DESC test_change -- !query 16 schema struct -- !query 16 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -186,6 +193,7 @@ DESC test_change -- !query 18 schema struct -- !query 18 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -229,6 +237,7 @@ DESC test_change -- !query 23 schema struct -- !query 23 output +# col_name data_type comment a int this is column A b string #*02?` c int diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 422d548ea8de..de10b29f3c65 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,9 +1,11 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 31 -- !query 0 -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) COMMENT 'table_comment' +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS + COMMENT 'table_comment' -- !query 0 schema struct<> -- !query 0 output @@ -11,7 +13,7 @@ struct<> -- !query 1 -ALTER TABLE t ADD PARTITION (c='Us', d=1) +CREATE TEMPORARY VIEW temp_v AS SELECT * FROM t -- !query 1 schema struct<> -- !query 1 output @@ -19,187 +21,239 @@ struct<> -- !query 2 -DESCRIBE t +CREATE TEMPORARY VIEW temp_Data_Source_View + USING 
org.apache.spark.sql.sources.DDLScanSource + OPTIONS ( + From '1', + To '10', + Table 'test1') -- !query 2 schema -struct +struct<> -- !query 2 output -# Partition Information + + + +-- !query 3 +CREATE VIEW v AS SELECT * FROM t +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +ALTER TABLE t ADD PARTITION (c='Us', d=1) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +DESCRIBE t +-- !query 5 schema +struct +-- !query 5 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 3 -DESC t --- !query 3 schema +-- !query 6 +DESC default.t +-- !query 6 schema struct --- !query 3 output -# Partition Information +-- !query 6 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 4 +-- !query 7 DESC TABLE t --- !query 4 schema +-- !query 7 schema struct --- !query 4 output -# Partition Information +-- !query 7 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 5 +-- !query 8 DESC FORMATTED t --- !query 5 schema +-- !query 8 schema struct --- !query 5 output -# Detailed Table Information -# Partition Information -# Storage Information +-- !query 8 output # col_name data_type comment -Comment: table_comment -Compressed: No -Created: -Database: default -Last Access: -Location: sql/core/spark-warehouse/t -Owner: -Partition Provider: Catalog -Storage Desc Parameters: -Table Parameters: -Table Type: MANAGED a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Location [not included in comparison]sql/core/spark-warehouse/t +Partition Provider Catalog --- !query 6 +-- !query 9 DESC EXTENDED t --- !query 6 schema +-- !query 9 schema struct --- !query 6 output -# Detailed Table Information CatalogTable( - Table: `default`.`t` - Created: - Last Access: - Type: MANAGED - Schema: [StructField(a,StringType,true), StructField(b,IntegerType,true), StructField(c,StringType,true), StructField(d,StringType,true)] - Provider: parquet - Partition Columns: [`c`, `d`] - Comment: table_comment - Storage(Location: sql/core/spark-warehouse/t) - Partition Provider: Catalog) -# Partition Information +-- !query 9 output # col_name data_type comment a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Location [not included in comparison]sql/core/spark-warehouse/t +Partition Provider Catalog --- !query 7 +-- !query 10 DESC t PARTITION (c='Us', d=1) --- !query 7 schema +-- !query 10 schema struct --- !query 7 output -# Partition Information +-- !query 10 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 8 +-- !query 11 DESC EXTENDED t 
PARTITION (c='Us', d=1) --- !query 8 schema +-- !query 11 schema struct --- !query 8 output -# Partition Information +-- !query 11 output # col_name data_type comment -Detailed Partition Information CatalogPartition( - Partition Values: [c=Us, d=1] - Storage(Location: sql/core/spark-warehouse/t/c=Us/d=1) - Partition Parameters:{}) a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Partition Information +Database default +Table t +Partition Values [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 + +# Storage Information +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Location [not included in comparison]sql/core/spark-warehouse/t --- !query 9 +-- !query 12 DESC FORMATTED t PARTITION (c='Us', d=1) --- !query 9 schema +-- !query 12 schema struct --- !query 9 output -# Detailed Partition Information -# Partition Information -# Storage Information +-- !query 12 output # col_name data_type comment -Compressed: No -Database: default -Location: sql/core/spark-warehouse/t/c=Us/d=1 -Partition Parameters: -Partition Value: [Us, 1] -Storage Desc Parameters: -Table: t a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Partition Information +Database default +Table t +Partition Values [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 + +# Storage Information +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Location [not included in comparison]sql/core/spark-warehouse/t --- !query 10 +-- !query 13 DESC t PARTITION (c='Us', d=2) --- !query 10 schema +-- !query 13 schema struct<> --- !query 10 output +-- !query 13 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 11 +-- !query 14 DESC t PARTITION (c='Us') --- !query 11 schema +-- !query 14 schema struct<> --- !query 11 output +-- !query 14 output org.apache.spark.sql.AnalysisException Partition spec is invalid. 
The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 12 +-- !query 15 DESC t PARTITION (c='Us', d) --- !query 12 schema +-- !query 15 schema struct<> --- !query 12 output +-- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -209,9 +263,193 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 13 +-- !query 16 +DESC temp_v +-- !query 16 schema +struct +-- !query 16 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 17 +DESC TABLE temp_v +-- !query 17 schema +struct +-- !query 17 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 18 +DESC FORMATTED temp_v +-- !query 18 schema +struct +-- !query 18 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 19 +DESC EXTENDED temp_v +-- !query 19 schema +struct +-- !query 19 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 20 +DESC temp_Data_Source_View +-- !query 20 schema +struct +-- !query 20 output +# col_name data_type comment +intType int test comment test1 +stringType string +dateType date +timestampType timestamp +doubleType double +bigintType bigint +tinyintType tinyint +decimalType decimal(10,0) +fixedDecimalType decimal(5,1) +binaryType binary +booleanType boolean +smallIntType smallint +floatType float +mapType map +arrayType array +structType struct + + +-- !query 21 +DESC temp_v PARTITION (c='Us', d=1) +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +DESC PARTITION is not allowed on a temporary view: temp_v; + + +-- !query 22 +DESC v +-- !query 22 schema +struct +-- !query 22 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 23 +DESC TABLE v +-- !query 23 schema +struct +-- !query 23 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 24 +DESC FORMATTED v +-- !query 24 schema +struct +-- !query 24 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table v +Created [not included in comparison] +Last Access [not included in comparison] +Type VIEW +View Text SELECT * FROM t +View Default Database default +View Query Output Columns [a, b, c, d] +Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] + + +-- !query 25 +DESC EXTENDED v +-- !query 25 schema +struct +-- !query 25 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table v +Created [not included in comparison] +Last Access [not included in comparison] +Type VIEW +View Text SELECT * FROM t +View Default Database default +View Query Output Columns [a, b, c, d] +Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] + + +-- !query 26 +DESC v PARTITION (c='Us', d=1) +-- !query 26 schema +struct<> +-- !query 26 output +org.apache.spark.sql.AnalysisException +DESC PARTITION is not allowed on a view: v; + + +-- !query 27 DROP TABLE t --- !query 13 schema +-- !query 27 schema struct<> --- !query 13 output +-- !query 27 output + + + +-- !query 28 +DROP VIEW temp_v +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +DROP VIEW 
temp_Data_Source_View +-- !query 29 schema +struct<> +-- !query 29 output + + + +-- !query 30 +DROP VIEW v +-- !query 30 schema +struct<> +-- !query 30 output diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 6d62e6092147..8f2a54f7c24e 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -118,33 +118,40 @@ SHOW TABLE EXTENDED LIKE 'show_t*' -- !query 12 schema struct -- !query 12 output -show_t3 true CatalogTable( - Table: `show_t3` - Created: - Last Access: - Type: VIEW - Schema: [StructField(e,IntegerType,true)] - Storage()) - -showdb show_t1 false CatalogTable( - Table: `showdb`.`show_t1` - Created: - Last Access: - Type: MANAGED - Schema: [StructField(a,StringType,true), StructField(b,IntegerType,true), StructField(c,StringType,true), StructField(d,StringType,true)] - Provider: parquet - Partition Columns: [`c`, `d`] - Storage(Location: sql/core/spark-warehouse/showdb.db/show_t1) - Partition Provider: Catalog) - -showdb show_t2 false CatalogTable( - Table: `showdb`.`show_t2` - Created: - Last Access: - Type: MANAGED - Schema: [StructField(b,StringType,true), StructField(d,IntegerType,true)] - Provider: parquet - Storage(Location: sql/core/spark-warehouse/showdb.db/show_t2)) +show_t3 true Table: show_t3 +Created [not included in comparison] +Last Access [not included in comparison] +Type: VIEW +Schema: root + |-- e: integer (nullable = true) + + +showdb show_t1 false Database: showdb +Table: show_t1 +Created [not included in comparison] +Last Access [not included in comparison] +Type: MANAGED +Provider: parquet +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1 +Partition Provider: Catalog +Partition Columns: [`c`, `d`] +Schema: root + |-- a: string (nullable = true) + |-- b: integer (nullable = true) + |-- c: string (nullable = true) + |-- d: string (nullable = true) + + +showdb show_t2 false Database: showdb +Table: show_t2 +Created [not included in comparison] +Last Access [not included in comparison] +Type: MANAGED +Provider: parquet +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t2 +Schema: root + |-- b: string (nullable = true) + |-- d: integer (nullable = true) -- !query 13 @@ -166,10 +173,8 @@ SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1) -- !query 14 schema struct -- !query 14 output -showdb show_t1 false CatalogPartition( - Partition Values: [c=Us, d=1] - Storage(Location: sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1) - Partition Parameters:{}) +showdb show_t1 false Partition Values: [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1 -- !query 15 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 4092862c430b..4b69baffab62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} +import org.apache.spark.sql.execution.command.DescribeTableCommand import 
org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -165,8 +166,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile); - val parent = resultFile.getParentFile(); + val resultFile = new File(testCase.resultFile) + val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) } @@ -212,23 +213,25 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { /** Executes a query and returns the result as (schema of the output, normalized output). */ private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = { // Returns true if the plan is supposed to be sorted. - def isSorted(plan: LogicalPlan): Boolean = plan match { + def needSort(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false + case _: DescribeTableCommand => true case PhysicalOperation(_, _, Sort(_, true, _)) => true - case _ => plan.children.iterator.exists(isSorted) + case _ => plan.children.iterator.exists(needSort) } try { val df = session.sql(sql) val schema = df.schema + val notIncludedMsg = "[not included in comparison]" // Get answer, but also get rid of the #1234 expression ids that show up in explain plans val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") - .replaceAll("Location:.*/sql/core/", "Location: sql/core/") - .replaceAll("Created: .*", "Created: ") - .replaceAll("Last Access: .*", "Last Access: ")) + .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") + .replaceAll("Created.*", s"Created $notIncludedMsg") + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")) // If the output is not pre-sorted, sort it. 
- if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) + if (needSort(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) } catch { case a: AnalysisException => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index a4d012cd7611..908b955abbf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -224,13 +224,13 @@ class SparkSqlParserSuite extends PlanTest { test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = false)) + TableIdentifier("t"), Map.empty, isExtended = false)) assertEqual("describe table extended t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = true, isFormatted = false)) + TableIdentifier("t"), Map.empty, isExtended = true)) assertEqual("describe table formatted t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = true)) + TableIdentifier("t"), Map.empty, isExtended = true)) intercept("explain describe tables x", "Unsupported SQL statement") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 648b1798c66e..9ebf2dd839a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -69,18 +69,6 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo tracksPartitionsInCatalog = true) } - test("desc table for parquet data source table using in-memory catalog") { - val tabName = "tab1" - withTable(tabName) { - sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ") - - checkAnswer( - sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("a", "int", "test") - ) - } - } - test("alter table: set location (datasource table)") { testSetLocation(isDatasourceTable = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 4a02277631f1..5bd36ec25ccb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -806,10 +806,6 @@ class JDBCSuite extends SparkFunSuite sql(s"DESC FORMATTED $tableName").collect().foreach { r => assert(!r.toString().contains(password)) } - - sql(s"DESC EXTENDED $tableName").collect().foreach { r => - assert(!r.toString().contains(password)) - } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala deleted file mode 100644 index 674463feca4d..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. 
-* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.sources - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -class DDLScanSource extends RelationProvider { - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - SimpleDDLScan( - parameters("from").toInt, - parameters("TO").toInt, - parameters("Table"))(sqlContext.sparkSession) - } -} - -case class SimpleDDLScan( - from: Int, - to: Int, - table: String)(@transient val sparkSession: SparkSession) - extends BaseRelation with TableScan { - - override def sqlContext: SQLContext = sparkSession.sqlContext - - override def schema: StructType = - StructType(Seq( - StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"), - StructField("stringType", StringType, nullable = false), - StructField("dateType", DateType, nullable = false), - StructField("timestampType", TimestampType, nullable = false), - StructField("doubleType", DoubleType, nullable = false), - StructField("bigintType", LongType, nullable = false), - StructField("tinyintType", ByteType, nullable = false), - StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), - StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), - StructField("binaryType", BinaryType, nullable = false), - StructField("booleanType", BooleanType, nullable = false), - StructField("smallIntType", ShortType, nullable = false), - StructField("floatType", FloatType, nullable = false), - StructField("mapType", MapType(StringType, StringType)), - StructField("arrayType", ArrayType(StringType)), - StructField("structType", - StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil - ) - ) - )) - - override def needConversion: Boolean = false - - override def buildScan(): RDD[Row] = { - // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] - sparkSession.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2) - }.asInstanceOf[RDD[Row]] - } -} - -class DDLTestSuite extends DataSourceTest with SharedSQLContext { - protected override lazy val sql = spark.sql _ - - override def beforeAll(): Unit = { - super.beforeAll() - sql( - """ - |CREATE OR REPLACE TEMPORARY VIEW ddlPeople - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) - } - - sqlTest( - "describe ddlPeople", - Seq( - Row("intType", "int", "test comment test1"), - Row("stringType", "string", null), - Row("dateType", "date", null), - Row("timestampType", "timestamp", null), - Row("doubleType", "double", null), - Row("bigintType", "bigint", null), - Row("tinyintType", "tinyint", null), - Row("decimalType", "decimal(10,0)", null), 
- Row("fixedDecimalType", "decimal(5,1)", null), - Row("binaryType", "binary", null), - Row("booleanType", "boolean", null), - Row("smallIntType", "smallint", null), - Row("floatType", "float", null), - Row("mapType", "map", null), - Row("arrayType", "array", null), - Row("structType", "struct", null) - )) - - test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = sql("describe ddlPeople") - .queryExecution.executedPlan.output - assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) - assert(attributes.map(_.dataType).toSet === Set(StringType)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index cc77d3c4b91a..80868fff897f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.sources +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] abstract class DataSourceTest extends QueryTest { @@ -28,3 +32,55 @@ private[sql] abstract class DataSourceTest extends QueryTest { } } + +class DDLScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleDDLScan( + parameters("from").toInt, + parameters("TO").toInt, + parameters("Table"))(sqlContext.sparkSession) + } +} + +case class SimpleDDLScan( + from: Int, + to: Int, + table: String)(@transient val sparkSession: SparkSession) + extends BaseRelation with TableScan { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = + StructType(Seq( + StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"), + StructField("stringType", StringType, nullable = false), + StructField("dateType", DateType, nullable = false), + StructField("timestampType", TimestampType, nullable = false), + StructField("doubleType", DoubleType, nullable = false), + StructField("bigintType", LongType, nullable = false), + StructField("tinyintType", ByteType, nullable = false), + StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), + StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), + StructField("binaryType", BinaryType, nullable = false), + StructField("booleanType", BooleanType, nullable = false), + StructField("smallIntType", ShortType, nullable = false), + StructField("floatType", FloatType, nullable = false), + StructField("mapType", MapType(StringType, StringType)), + StructField("arrayType", ArrayType(StringType)), + StructField("structType", + StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil + ) + ) + )) + + override def needConversion: Boolean = false + + override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] + sparkSession.sparkContext.parallelize(from to to).map { e => + InternalRow(UTF8String.fromString(s"people$e"), e * 2) + }.asInstanceOf[RDD[Row]] + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 
55e02acfa4ce..b55469481557 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -767,9 +767,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sessionState.refreshTable(tableName) val actualSchema = table(tableName).schema assert(schema === actualSchema) - - // Checks the DESCRIBE output. - checkAnswer(sql("DESCRIBE spark6655"), Row("int", "int", null) :: Nil) } } @@ -1381,7 +1378,10 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv checkAnswer(spark.table("old"), Row(1, "a")) - checkAnswer(sql("DESC old"), Row("i", "int", null) :: Row("j", "string", null) :: Nil) + val expectedSchema = StructType(Seq( + StructField("i", IntegerType, nullable = true), + StructField("j", StringType, nullable = true))) + assert(table("old").schema === expectedSchema) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 536ca8fd9d45..e45cf977bfaa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -207,6 +207,7 @@ abstract class HiveComparisonTest // This list contains indicators for those lines which do not have actual results and we // want to ignore. lazy val ignoredLineIndicators = Seq( + "# Detailed Table Information", "# Partition Information", "# col_name" ) @@ -358,7 +359,7 @@ abstract class HiveComparisonTest stringToFile(new File(failedDirectory, testCaseName), errorMessage + consoleTestCase) fail(errorMessage) } - }.toSeq + } (queryList, hiveResults, catalystResults).zipped.foreach { case (query, hive, (hiveQuery, catalyst)) => @@ -369,6 +370,7 @@ abstract class HiveComparisonTest if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && (!hiveQuery.logical.isInstanceOf[ShowFunctionsCommand]) && (!hiveQuery.logical.isInstanceOf[DescribeFunctionCommand]) && + (!hiveQuery.logical.isInstanceOf[DescribeTableCommand]) && preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f0a995c274b6..3906968aaff1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -708,23 +708,6 @@ class HiveDDLSuite } } - test("desc table for Hive table") { - withTable("tab1") { - val tabName = "tab1" - sql(s"CREATE TABLE $tabName(c1 int)") - - assert(sql(s"DESC $tabName").collect().length == 1) - - assert( - sql(s"DESC FORMATTED $tabName").collect() - .exists(_.getString(0) == "# Storage Information")) - - assert( - sql(s"DESC EXTENDED $tabName").collect() - .exists(_.getString(0) == "# Detailed Table Information")) - } - } - test("desc table for Hive table - partitioned table") { withTable("tbl") { sql("CREATE TABLE tbl(a int) PARTITIONED BY (b int)") @@ -741,23 +724,6 @@ class HiveDDLSuite } } - test("desc formatted table for permanent view") { - withTable("tbl") { - withView("view1") { - sql("CREATE TABLE tbl(a int)") - sql("CREATE VIEW view1 AS SELECT * FROM tbl") - assert(sql("DESC FORMATTED 
view1").collect().containsSlice( - Seq( - Row("# View Information", "", ""), - Row("View Text:", "SELECT * FROM tbl", ""), - Row("View Default Database:", "default", ""), - Row("View Query Output Columns:", "[a]", "") - ) - )) - } - } - } - test("desc table for data source table using Hive Metastore") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") val tabName = "tab1" @@ -766,7 +732,7 @@ class HiveDDLSuite checkAnswer( sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("a", "int", "test") + Row("# col_name", "data_type", "comment") :: Row("a", "int", "test") :: Nil ) } } @@ -1218,23 +1184,6 @@ class HiveDDLSuite sql(s"SELECT * FROM ${targetTable.identifier}")) } - test("desc table for data source table") { - withTable("tab1") { - val tabName = "tab1" - spark.range(1).write.format("json").saveAsTable(tabName) - - assert(sql(s"DESC $tabName").collect().length == 1) - - assert( - sql(s"DESC FORMATTED $tabName").collect() - .exists(_.getString(0) == "# Storage Information")) - - assert( - sql(s"DESC EXTENDED $tabName").collect() - .exists(_.getString(0) == "# Detailed Table Information")) - } - } - test("create table with the same name as an index table") { val tabName = "tab1" val indexName = tabName + "_index" @@ -1320,46 +1269,6 @@ class HiveDDLSuite } } - test("desc table for data source table - partitioned bucketed table") { - withTable("t1") { - spark - .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write - .bucketBy(2, "b").sortBy("c").partitionBy("d") - .saveAsTable("t1") - - val formattedDesc = sql("DESC FORMATTED t1").collect() - - assert(formattedDesc.containsSlice( - Seq( - Row("a", "bigint", null), - Row("b", "bigint", null), - Row("c", "bigint", null), - Row("d", "bigint", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("d", "bigint", null), - Row("", "", ""), - Row("# Detailed Table Information", "", ""), - Row("Database:", "default", "") - ) - )) - - assert(formattedDesc.containsSlice( - Seq( - Row("Table Type:", "MANAGED", "") - ) - )) - - assert(formattedDesc.containsSlice( - Seq( - Row("Num Buckets:", "2", ""), - Row("Bucket Columns:", "[b]", ""), - Row("Sort Columns:", "[c]", "") - ) - )) - } - } - test("datasource and statistics table property keys are not allowed") { import org.apache.spark.sql.hive.HiveExternalCatalog.DATASOURCE_PREFIX import org.apache.spark.sql.hive.HiveExternalCatalog.STATISTICS_PREFIX diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala deleted file mode 100644 index 0e89e990e564..000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHiveSingleton - -/** - * A set of tests that validates commands can also be queried by like a table - */ -class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton { - import spark._ - - test("SPARK-5324 query result of describe command") { - hiveContext.loadTestTable("src") - - // Creates a temporary view with the output of a describe command - sql("desc src").createOrReplaceTempView("mydesc") - checkAnswer( - sql("desc mydesc"), - Seq( - Row("col_name", "string", "name of the column"), - Row("data_type", "string", "data type of the column"), - Row("comment", "string", "comment of the column"))) - - checkAnswer( - sql("select * from mydesc"), - Seq( - Row("key", "int", null), - Row("value", "string", null))) - - checkAnswer( - sql("select col_name, data_type, comment from mydesc"), - Seq( - Row("key", "int", null), - Row("value", "string", null))) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index dd278f683a3c..65a902fc5438 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -789,62 +789,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd assert(Try(q0.count()).isSuccess) } - test("DESCRIBE commands") { - sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") - - sql( - """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') - |SELECT key, value - """.stripMargin) - - // Describe a table - assertResult( - Array( - Row("key", "int", null), - Row("value", "string", null), - Row("dt", "string", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("dt", "string", null)) - ) { - sql("DESCRIBE test_describe_commands1") - .select('col_name, 'data_type, 'comment) - .collect() - } - - // Describe a table with a fully qualified table name - assertResult( - Array( - Row("key", "int", null), - Row("value", "string", null), - Row("dt", "string", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("dt", "string", null)) - ) { - sql("DESCRIBE default.test_describe_commands1") - .select('col_name, 'data_type, 'comment) - .collect() - } - - // Describe a temporary view. 
- val testData = - TestHive.sparkContext.parallelize( - TestData(1, "str1") :: - TestData(1, "str2") :: Nil) - testData.toDF().createOrReplaceTempView("test_describe_commands2") - - assertResult( - Array( - Row("a", "int", null), - Row("b", "string", null)) - ) { - sql("DESCRIBE test_describe_commands2") - .select('col_name, 'data_type, 'comment) - .collect() - } - } - test("SPARK-2263: Insert Map values") { sql("CREATE TABLE m(value MAP)") sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 55ff4bb115e5..d012797e1992 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -363,79 +363,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("describe partition") { - withTable("partitioned_table") { - sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") - sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") - - checkKeywordsExist(sql("DESC partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name") - - checkKeywordsExist(sql("DESC EXTENDED partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name", - "Detailed Partition Information CatalogPartition(", - "Partition Values: [c=Us, d=1]", - "Storage(Location:", - "Partition Parameters") - - checkKeywordsExist(sql("DESC FORMATTED partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name", - "# Detailed Partition Information", - "Partition Value:", - "Database:", - "Table:", - "Location:", - "Partition Parameters:", - "# Storage Information") - } - } - - test("describe partition - error handling") { - withTable("partitioned_table", "datasource_table") { - sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") - sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") - - val m = intercept[NoSuchPartitionException] { - sql("DESC partitioned_table PARTITION (c='Us', d=2)") - }.getMessage() - assert(m.contains("Partition not found in table")) - - val m2 = intercept[AnalysisException] { - sql("DESC partitioned_table PARTITION (c='Us')") - }.getMessage() - assert(m2.contains("Partition spec is invalid")) - - val m3 = intercept[ParseException] { - sql("DESC partitioned_table PARTITION (c='Us', d)") - }.getMessage() - assert(m3.contains("PARTITION specification is incomplete: `d`")) - - spark - .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write - .partitionBy("d") - .saveAsTable("datasource_table") - - sql("DESC datasource_table PARTITION (d=0)") - - val m5 = intercept[AnalysisException] { - spark.range(10).select('id as 'a, 'id as 'b).createTempView("view1") - sql("DESC view1 PARTITION (c='Us', d=1)") - }.getMessage() - assert(m5.contains("DESC PARTITION is not allowed on a temporary view")) - - withView("permanent_view") { - val m = intercept[AnalysisException] { - sql("CREATE VIEW permanent_view AS SELECT * FROM partitioned_table") - sql("DESC permanent_view PARTITION (c='Us', d=1)") - }.getMessage() - assert(m.contains("DESC PARTITION is not allowed on a view")) - } - } - } - test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") 
df.createOrReplaceTempView("table1") @@ -676,7 +603,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS with serde") { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") sql( """CREATE TABLE ctas2 | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" @@ -686,86 +613,76 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) + + val storageCtas2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas2")).storage + assert(storageCtas2.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(storageCtas2.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(storageCtas2.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + sql( """CREATE TABLE ctas3 | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' | STORED AS textfile AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) // the table schema may like (key: integer, value: string) sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect() + | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin) // do nothing cause the table ctas4 already existed. sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() + | SELECT key, value FROM src ORDER BY key, value""".stripMargin) checkAnswer( sql("SELECT k, value FROM ctas1 ORDER BY k, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql("SELECT key, value FROM src ORDER BY key, value")) checkAnswer( sql("SELECT key, value FROM ctas2 ORDER BY key, value"), sql( """ SELECT key, value FROM src - ORDER BY key, value""").collect().toSeq) + ORDER BY key, value""")) checkAnswer( sql("SELECT key, value FROM ctas3 ORDER BY key, value"), sql( """ SELECT key, value FROM src - ORDER BY key, value""").collect().toSeq) + ORDER BY key, value""")) intercept[AnalysisException] { sql( """CREATE TABLE ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() + | SELECT key, value FROM src ORDER BY key, value""".stripMargin) } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) - /* - Disabled because our describe table does not output the serde information right now. 
- checkKeywordsExist(sql("DESC EXTENDED ctas2"), - "name:key", "type:string", "name:value", "ctas2", - "org.apache.hadoop.hive.ql.io.RCFileInputFormat", - "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", - "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", - "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" - ) - */ - sql( """CREATE TABLE ctas5 | STORED AS parquet AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) + val storageCtas5 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas5")).storage + assert(storageCtas5.inputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(storageCtas5.outputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(storageCtas5.serde == + Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - /* - Disabled because our describe table does not output the serde information right now. - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { - checkKeywordsExist(sql("DESC EXTENDED ctas5"), - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - } - */ // use the Hive SerDe for parquet tables withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { checkAnswer( sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql("SELECT key, value FROM src ORDER BY key, value")) } } From b34f7665ddb0a40044b4c2bc7d351599c125cb13 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 3 Apr 2017 23:42:04 -0700 Subject: [PATCH 201/512] [SPARK-19825][R][ML] spark.ml R API for FPGrowth ## What changes were proposed in this pull request? Adds SparkR API for FPGrowth: [SPARK-19825](https://issues.apache.org/jira/browse/SPARK-19825): - `spark.fpGrowth` -model training. - `freqItemsets` and `associationRules` methods with new corresponding generics. - Scala helper: `org.apache.spark.ml.r. FPGrowthWrapper` - unit tests. ## How was this patch tested? Feature specific unit tests. Author: zero323 Closes #17170 from zero323/SPARK-19825. 
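The new R methods are thin wrappers over Spark ML's Scala `FPGrowth` estimator, wired up through `org.apache.spark.ml.r.FPGrowthWrapper`. For reference, a rough Scala sketch of the equivalent calls the wrapper ends up making (the input data mirrors the R unit test; the threshold values are illustrative only):

```scala
import org.apache.spark.ml.fpm.FPGrowth
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("fpgrowth-sketch").getOrCreate()
import spark.implicits._

// Transactions as an array<string> column, the same shape the R API feeds in
// via "split(items, ',') as items".
val data = Seq(
  Seq("1", "2"), Seq("1", "2"), Seq("1", "2", "3"), Seq("1", "3")).toDF("items")

// Same knobs spark.fpGrowth exposes: minSupport, minConfidence, itemsCol
// (and optionally numPartitions).
val model = new FPGrowth()
  .setMinSupport(0.3)
  .setMinConfidence(0.8)
  .setItemsCol("items")
  .fit(data)

model.freqItemsets.show()      // exposed in R as spark.freqItemsets(model)
model.associationRules.show()  // exposed in R as spark.associationRules(model)
model.transform(data).show()   // exposed in R as predict(model, newData)
```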
--- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 5 +- R/pkg/R/generics.R | 12 ++ R/pkg/R/mllib_fpm.R | 158 ++++++++++++++++++ R/pkg/R/mllib_utils.R | 2 + R/pkg/inst/tests/testthat/test_mllib_fpm.R | 83 +++++++++ .../apache/spark/ml/r/FPGrowthWrapper.scala | 86 ++++++++++ .../org/apache/spark/ml/r/RWrappers.scala | 2 + 8 files changed, 348 insertions(+), 1 deletion(-) create mode 100644 R/pkg/R/mllib_fpm.R create mode 100644 R/pkg/inst/tests/testthat/test_mllib_fpm.R create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 00dde64324ae..f475ee87702e 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -44,6 +44,7 @@ Collate: 'jvm.R' 'mllib_classification.R' 'mllib_clustering.R' + 'mllib_fpm.R' 'mllib_recommendation.R' 'mllib_regression.R' 'mllib_stat.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c02046c94bf4..9b7e95ce30ac 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -66,7 +66,10 @@ exportMethods("glm", "spark.randomForest", "spark.gbt", "spark.bisectingKmeans", - "spark.svmLinear") + "spark.svmLinear", + "spark.fpGrowth", + "spark.freqItemsets", + "spark.associationRules") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 80283e48ced7..945676c7f10b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1445,6 +1445,18 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) + #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R new file mode 100644 index 000000000000..96251b2c7c19 --- /dev/null +++ b/R/pkg/R/mllib_fpm.R @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_fpm.R: Provides methods for MLlib frequent pattern mining algorithms integration + +#' S4 class that represents a FPGrowthModel +#' +#' @param jobj a Java object reference to the backing Scala FPGrowthModel +#' @export +#' @note FPGrowthModel since 2.2.0 +setClass("FPGrowthModel", slots = list(jobj = "jobj")) + +#' FP-growth +#' +#' A parallel FP-growth algorithm to mine frequent itemsets. 
+#' For more details, see +#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{ +#' FP-growth}. +#' +#' @param data A SparkDataFrame for training. +#' @param minSupport Minimal support level. +#' @param minConfidence Minimal confidence level. +#' @param itemsCol Features column name. +#' @param numPartitions Number of partitions used for fitting. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.fpGrowth} returns a fitted FPGrowth model. +#' @rdname spark.fpGrowth +#' @name spark.fpGrowth +#' @aliases spark.fpGrowth,SparkDataFrame-method +#' @export +#' @examples +#' \dontrun{ +#' raw_data <- read.df( +#' "data/mllib/sample_fpgrowth.txt", +#' source = "csv", +#' schema = structType(structField("raw_items", "string"))) +#' +#' data <- selectExpr(raw_data, "split(raw_items, ' ') as items") +#' model <- spark.fpGrowth(data) +#' +#' # Show frequent itemsets +#' frequent_itemsets <- spark.freqItemsets(model) +#' showDF(frequent_itemsets) +#' +#' # Show association rules +#' association_rules <- spark.associationRules(model) +#' showDF(association_rules) +#' +#' # Predict on new data +#' new_itemsets <- data.frame(items = c("t", "t,s")) +#' new_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as items") +#' predict(model, new_data) +#' +#' # Save and load model +#' path <- "/path/to/model" +#' write.ml(model, path) +#' read.ml(path) +#' +#' # Optional arguments +#' baskets_data <- selectExpr(createDataFrame(itemsets), "split(items, ',') as baskets") +#' another_model <- spark.fpGrowth(data, minSupport = 0.1, minConfidence = 0.5, +#' itemsCol = "baskets", numPartitions = 10) +#' } +#' @note spark.fpGrowth since 2.2.0 +setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), + function(data, minSupport = 0.3, minConfidence = 0.8, + itemsCol = "items", numPartitions = NULL) { + if (!is.numeric(minSupport) || minSupport < 0 || minSupport > 1) { + stop("minSupport should be a number [0, 1].") + } + if (!is.numeric(minConfidence) || minConfidence < 0 || minConfidence > 1) { + stop("minConfidence should be a number [0, 1].") + } + if (!is.null(numPartitions)) { + numPartitions <- as.integer(numPartitions) + stopifnot(numPartitions > 0) + } + + jobj <- callJStatic("org.apache.spark.ml.r.FPGrowthWrapper", "fit", + data@sdf, as.numeric(minSupport), as.numeric(minConfidence), + itemsCol, numPartitions) + new("FPGrowthModel", jobj = jobj) + }) + +# Get frequent itemsets. + +#' @param object a fitted FPGrowth model. +#' @return A \code{SparkDataFrame} with frequent itemsets. +#' The \code{SparkDataFrame} contains two columns: +#' \code{items} (an array of the same type as the input column) +#' and \code{freq} (frequency of the itemset). +#' @rdname spark.fpGrowth +#' @aliases freqItemsets,FPGrowthModel-method +#' @export +#' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 +setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), + function(object) { + dataFrame(callJMethod(object@jobj, "freqItemsets")) + }) + +# Get association rules. + +#' @return A \code{SparkDataFrame} with association rules. +#' The \code{SparkDataFrame} contains three columns: +#' \code{antecedent} (an array of the same type as the input column), +#' \code{consequent} (an array of the same type as the input column), +#' and \code{condfidence} (confidence). 
+#' @rdname spark.fpGrowth +#' @aliases associationRules,FPGrowthModel-method +#' @export +#' @note spark.associationRules(FPGrowthModel) since 2.2.0 +setMethod("spark.associationRules", signature(object = "FPGrowthModel"), + function(object) { + dataFrame(callJMethod(object@jobj, "associationRules")) + }) + +# Makes predictions based on generated association rules + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.fpGrowth +#' @aliases predict,FPGrowthModel-method +#' @export +#' @note predict(FPGrowthModel) since 2.2.0 +setMethod("predict", signature(object = "FPGrowthModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the FPGrowth model to the output path. + +#' @param path the directory where the model is saved. +#' @param overwrite logical value indicating whether to overwrite if the output path +#' already exists. Default is FALSE which means throw exception +#' if the output path exists. +#' @rdname spark.fpGrowth +#' @aliases write.ml,FPGrowthModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(FPGrowthModel, character) since 2.2.0 +setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 04a0a6f94441..5dfef8625061 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -118,6 +118,8 @@ read.ml <- function(path) { new("BisectingKMeansModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) { new("LinearSVCModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) { + new("FPGrowthModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R new file mode 100644 index 000000000000..c38f1133897d --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib frequent pattern mining") + +# Tests for MLlib frequent pattern mining algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("spark.fpGrowth", { + data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,2", + "1,2,3", + "1,3" + ))), "split(items, ',') as items") + + model <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8, numPartitions = 1) + + itemsets <- collect(spark.freqItemsets(model)) + + expected_itemsets <- data.frame( + items = I(list(list("3"), list("3", "1"), list("2"), list("2", "1"), list("1"))), + freq = c(2, 2, 3, 3, 4) + ) + + expect_equivalent(expected_itemsets, itemsets) + + expected_association_rules <- data.frame( + antecedent = I(list(list("2"), list("3"))), + consequent = I(list(list("1"), list("1"))), + confidence = c(1, 1) + ) + + expect_equivalent(expected_association_rules, collect(spark.associationRules(model))) + + new_data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,3", + "2,3" + ))), "split(items, ',') as items") + + expected_predictions <- data.frame( + items = I(list(list("1", "2"), list("1", "3"), list("2", "3"))), + prediction = I(list(list(), list(), list("1"))) + ) + + expect_equivalent(expected_predictions, collect(predict(model, new_data))) + + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) + + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) + + unlink(modelPath) + + model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) + expect_equal( + count(spark.freqItemsets(model_without_numpartitions)), + count(spark.freqItemsets(model)) + ) + +}) + +sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala new file mode 100644 index 000000000000..b8151d8d9070 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class FPGrowthWrapper private (val fpGrowthModel: FPGrowthModel) extends MLWritable { + def freqItemsets: DataFrame = fpGrowthModel.freqItemsets + def associationRules: DataFrame = fpGrowthModel.associationRules + + def transform(dataset: Dataset[_]): DataFrame = { + fpGrowthModel.transform(dataset) + } + + override def write: MLWriter = new FPGrowthWrapper.FPGrowthWrapperWriter(this) +} + +private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] { + + def fit( + data: DataFrame, + minSupport: Double, + minConfidence: Double, + itemsCol: String, + numPartitions: Integer): FPGrowthWrapper = { + val fpGrowth = new FPGrowth() + .setMinSupport(minSupport) + .setMinConfidence(minConfidence) + .setItemsCol(itemsCol) + + if (numPartitions != null && numPartitions > 0) { + fpGrowth.setNumPartitions(numPartitions) + } + + val fpGrowthModel = fpGrowth.fit(data) + + new FPGrowthWrapper(fpGrowthModel) + } + + override def read: MLReader[FPGrowthWrapper] = new FPGrowthWrapperReader + + class FPGrowthWrapperReader extends MLReader[FPGrowthWrapper] { + override def load(path: String): FPGrowthWrapper = { + val modelPath = new Path(path, "model").toString + val fPGrowthModel = FPGrowthModel.load(modelPath) + + new FPGrowthWrapper(fPGrowthModel) + } + } + + class FPGrowthWrapperWriter(instance: FPGrowthWrapper) extends MLWriter { + override protected def saveImpl(path: String): Unit = { + val modelPath = new Path(path, "model").toString + val rMetadataPath = new Path(path, "rMetadata").toString + + val rMetadataJson: String = compact(render( + "class" -> instance.getClass.getName + )) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.fpGrowthModel.save(modelPath) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 358e522dfe1c..b30ce12bc6cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -68,6 +68,8 @@ private[r] object RWrappers extends MLReader[Object] { BisectingKMeansWrapper.load(path) case "org.apache.spark.ml.r.LinearSVCWrapper" => LinearSVCWrapper.load(path) + case "org.apache.spark.ml.r.FPGrowthWrapper" => + FPGrowthWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } From c95fbea68e9dfb2c96a1d13dde17d80a37066ae6 Mon Sep 17 00:00:00 2001 From: guoxiaolongzte Date: Tue, 4 Apr 2017 09:56:17 +0100 Subject: [PATCH 202/512] =?UTF-8?q?[SPARK-20190][APP-ID]=20applications//j?= =?UTF-8?q?obs'=20in=20rest=20api,status=20should=20be=20[running|s?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ucceeded|failed|unknown] ## What changes were proposed in this pull request? '/applications/[app-id]/jobs' in rest api.status should be'[running|succeeded|failed|unknown]'. now status is '[complete|succeeded|failed]'. but '/applications/[app-id]/jobs?status=complete' the server return 'HTTP ERROR 404'. Added '?status=running' and '?status=unknown'. code : public enum JobExecutionStatus { RUNNING, SUCCEEDED, FAILED, UNKNOWN; ## How was this patch tested? 
manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolongzte Closes #17507 from guoxiaolongzte/SPARK-20190. --- docs/monitoring.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index 6cbc6660e816..4d0617d253b8 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -289,7 +289,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running
    From 26e7bca2295faeef22b2d9554f316c97bc240fd7 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 4 Apr 2017 18:57:46 +0800 Subject: [PATCH 203/512] [SPARK-20198][SQL] Remove the inconsistency in table/function name conventions in SparkSession.Catalog APIs ### What changes were proposed in this pull request? Observed by felixcheung , in `SparkSession`.`Catalog` APIs, we have different conventions/rules for table/function identifiers/names. Most APIs accept the qualified name (i.e., `databaseName`.`tableName` or `databaseName`.`functionName`). However, the following five APIs do not accept it. - def listColumns(tableName: String): Dataset[Column] - def getTable(tableName: String): Table - def getFunction(functionName: String): Function - def tableExists(tableName: String): Boolean - def functionExists(functionName: String): Boolean To make them consistent with the other Catalog APIs, this PR does the changes, updates the function/API comments and adds the `params` to clarify the inputs we allow. ### How was this patch tested? Added the test cases . Author: Xiao Li Closes #17518 from gatorsmile/tableIdentifier. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 8 ++ .../sql/catalyst/parser/AstBuilder.scala | 13 +++ .../sql/catalyst/parser/ParseDriver.scala | 7 +- .../sql/catalyst/parser/ParserInterface.scala | 5 +- .../org/apache/spark/sql/SparkSession.scala | 7 +- .../apache/spark/sql/catalog/Catalog.scala | 109 +++++++++++++++--- .../spark/sql/internal/CatalogImpl.scala | 73 ++++++------ .../spark/sql/internal/CatalogSuite.scala | 21 ++++ 8 files changed, 186 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index c4a590ec6916..52b5b347fa9c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -56,6 +56,10 @@ singleTableIdentifier : tableIdentifier EOF ; +singleFunctionIdentifier + : functionIdentifier EOF + ; + singleDataType : dataType EOF ; @@ -493,6 +497,10 @@ tableIdentifier : (db=identifier '.')? table=identifier ; +functionIdentifier + : (db=identifier '.')? function=identifier + ; + namedExpression : expression (AS? (identifier | identifierList))? ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 162051a8c0e4..fab7e4c5b128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -75,6 +75,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { visitTableIdentifier(ctx.tableIdentifier) } + override def visitSingleFunctionIdentifier( + ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + visitFunctionIdentifier(ctx.functionIdentifier) + } + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { visitSparkDataType(ctx.dataType) } @@ -759,6 +764,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) } + /** + * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern. 
+ */ + override def visitFunctionIdentifier( + ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) + } + /* ******************************************************************************************** * Expression parsing * ******************************************************************************************** */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index f704b0998cad..80ab75cc17fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -22,7 +22,7 @@ import org.antlr.v4.runtime.misc.ParseCancellationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin @@ -49,6 +49,11 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) } + /** Creates FunctionIdentifier for a given SQL string. */ + def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + } + /** * Creates StructType for a given SQL string, which is a comma separated list of field * definitions which will preserve the correct Hive metadata. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index 6edbe253970e..db3598bde04d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StructType @@ -35,6 +35,9 @@ trait ParserInterface { /** Creates TableIdentifier for a given SQL string. */ def parseTableIdentifier(sqlText: String): TableIdentifier + /** Creates FunctionIdentifier for a given SQL string. */ + def parseFunctionIdentifier(sqlText: String): FunctionIdentifier + /** * Creates StructType for a given SQL string, which is a comma separated list of field * definitions which will preserve the correct Hive metadata. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index b60499253c42..95f3463dfe62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -591,8 +591,13 @@ class SparkSession private( @transient lazy val catalog: Catalog = new CatalogImpl(self) /** - * Returns the specified table as a `DataFrame`. 
+ * Returns the specified table/view as a `DataFrame`. * + * @param tableName is either a qualified or unqualified name that designates a table or view. + * If a database is specified, it identifies the table/view from the database. + * Otherwise, it first attempts to find a temporary view with the given name + * and then match the table/view from the current database. + * Note that, the global temporary view database is also valid here. * @since 2.0.0 */ def table(tableName: String): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 50252db789d4..137b0cbc84f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -54,16 +54,16 @@ abstract class Catalog { def listDatabases(): Dataset[Database] /** - * Returns a list of tables in the current database. - * This includes all temporary tables. + * Returns a list of tables/views in the current database. + * This includes all temporary views. * * @since 2.0.0 */ def listTables(): Dataset[Table] /** - * Returns a list of tables in the specified database. - * This includes all temporary tables. + * Returns a list of tables/views in the specified database. + * This includes all temporary views. * * @since 2.0.0 */ @@ -88,17 +88,21 @@ abstract class Catalog { def listFunctions(dbName: String): Dataset[Function] /** - * Returns a list of columns for the given table in the current database or - * the given temporary table. + * Returns a list of columns for the given table/view or temporary view. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ @throws[AnalysisException]("table does not exist") def listColumns(tableName: String): Dataset[Column] /** - * Returns a list of columns for the given table in the specified database. + * Returns a list of columns for the given table/view in the specified database. * + * @param dbName is a name that designates a database. + * @param tableName is an unqualified name that designates a table/view. * @since 2.0.0 */ @throws[AnalysisException]("database or table does not exist") @@ -115,9 +119,11 @@ abstract class Catalog { /** * Get the table or view with the specified name. This table can be a temporary view or a - * table/view in the current database. This throws an AnalysisException when no Table - * can be found. + * table/view. This throws an AnalysisException when no Table can be found. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a table/view in + * the current database. * @since 2.1.0 */ @throws[AnalysisException]("table does not exist") @@ -134,9 +140,11 @@ abstract class Catalog { /** * Get the function with the specified name. This function can be a temporary function or a - * function in the current database. This throws an AnalysisException when the function cannot - * be found. + * function. This throws an AnalysisException when the function cannot be found. * + * @param functionName is either a qualified or unqualified name that designates a function. + * If no database identifier is provided, it refers to a temporary function + * or a function in the current database. 
* @since 2.1.0 */ @throws[AnalysisException]("function does not exist") @@ -146,6 +154,8 @@ abstract class Catalog { * Get the function with the specified name. This throws an AnalysisException when the function * cannot be found. * + * @param dbName is a name that designates a database. + * @param functionName is an unqualified name that designates a function in the specified database * @since 2.1.0 */ @throws[AnalysisException]("database or function does not exist") @@ -160,8 +170,11 @@ abstract class Catalog { /** * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view in the current database. + * view or a table/view. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a table/view in + * the current database. * @since 2.1.0 */ def tableExists(tableName: String): Boolean @@ -169,14 +182,19 @@ abstract class Catalog { /** * Check if the table or view with the specified name exists in the specified database. * + * @param dbName is a name that designates a database. + * @param tableName is an unqualified name that designates a table. * @since 2.1.0 */ def tableExists(dbName: String, tableName: String): Boolean /** * Check if the function with the specified name exists. This can either be a temporary function - * or a function in the current database. + * or a function. * + * @param functionName is either a qualified or unqualified name that designates a function. + * If no database identifier is provided, it refers to a function in + * the current database. * @since 2.1.0 */ def functionExists(functionName: String): Boolean @@ -184,6 +202,8 @@ abstract class Catalog { /** * Check if the function with the specified name exists in the specified database. * + * @param dbName is a name that designates a database. + * @param functionName is an unqualified name that designates a function. * @since 2.1.0 */ def functionExists(dbName: String, functionName: String): Boolean @@ -192,6 +212,9 @@ abstract class Catalog { * Creates a table from the given path and returns the corresponding DataFrame. * It will use the default data source configured by spark.sql.sources.default. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -204,6 +227,9 @@ abstract class Catalog { * Creates a table from the given path and returns the corresponding DataFrame. * It will use the default data source configured by spark.sql.sources.default. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -214,6 +240,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and returns the corresponding * DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -226,6 +255,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and returns the corresponding * DataFrame. 
* + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -236,6 +268,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -251,6 +286,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -267,6 +305,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -283,6 +324,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -297,6 +341,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -313,6 +360,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -330,6 +380,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -347,6 +400,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. 
* @since 2.2.0 */ @Experimental @@ -368,7 +424,7 @@ abstract class Catalog { * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean * in Spark 2.1. * - * @param viewName the name of the view to be dropped. + * @param viewName the name of the temporary view to be dropped. * @return true if the view is dropped successfully, false otherwise. * @since 2.0.0 */ @@ -383,15 +439,18 @@ abstract class Catalog { * preserved database `global_temp`, and we must use the qualified name to refer a global temp * view, e.g. `SELECT * FROM global_temp.view1`. * - * @param viewName the name of the view to be dropped. + * @param viewName the unqualified name of the temporary view to be dropped. * @return true if the view is dropped successfully, false otherwise. * @since 2.1.0 */ def dropGlobalTempView(viewName: String): Boolean /** - * Recover all the partitions in the directory of a table and update the catalog. + * Recovers all the partitions in the directory of a table and update the catalog. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in the + * current database. * @since 2.1.1 */ def recoverPartitions(tableName: String): Unit @@ -399,6 +458,9 @@ abstract class Catalog { /** * Returns true if the table is currently cached in-memory. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def isCached(tableName: String): Boolean @@ -406,6 +468,9 @@ abstract class Catalog { /** * Caches the specified table in-memory. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def cacheTable(tableName: String): Unit @@ -413,6 +478,9 @@ abstract class Catalog { /** * Removes the specified table from the in-memory cache. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def uncacheTable(tableName: String): Unit @@ -425,7 +493,7 @@ abstract class Catalog { def clearCache(): Unit /** - * Invalidate and refresh all the cached metadata of the given table. For performance reasons, + * Invalidates and refreshes all the cached metadata of the given table. For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a * table, such as the location of blocks. When those change outside of Spark SQL, users should * call this function to invalidate the cache. @@ -433,13 +501,16 @@ abstract class Catalog { * If this table is cached as an InMemoryRelation, drop the original cached version and make the * new version cached lazily. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def refreshTable(tableName: String): Unit /** - * Invalidate and refresh all the cached data (and the associated metadata) for any dataframe that - * contains the given data source path. Path matching is by prefix, i.e. 
"/" would invalidate + * Invalidates and refreshes all the cached data (and the associated metadata) for any [[Dataset]] + * that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate * everything that is cached. * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 53374859f13f..5d1c35aba529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag -import org.apache.hadoop.fs.Path - import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table} @@ -143,11 +141,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns a list of columns for the given table in the current database. + * Returns a list of columns for the given table temporary view. */ @throws[AnalysisException]("table does not exist") override def listColumns(tableName: String): Dataset[Column] = { - listColumns(TableIdentifier(tableName, None)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + listColumns(tableIdent) } /** @@ -177,7 +176,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Get the database with the specified name. This throws an `AnalysisException` when no + * Gets the database with the specified name. This throws an `AnalysisException` when no * `Database` can be found. */ override def getDatabase(dbName: String): Database = { @@ -185,16 +184,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Get the table or view with the specified name. This table can be a temporary view or a - * table/view in the current database. This throws an `AnalysisException` when no `Table` - * can be found. + * Gets the table or view with the specified name. This table can be a temporary view or a + * table/view. This throws an `AnalysisException` when no `Table` can be found. */ override def getTable(tableName: String): Table = { - getTable(null, tableName) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + getTable(tableIdent.database.orNull, tableIdent.table) } /** - * Get the table or view with the specified name in the specified database. This throws an + * Gets the table or view with the specified name in the specified database. This throws an * `AnalysisException` when no `Table` can be found. */ override def getTable(dbName: String, tableName: String): Table = { @@ -202,16 +201,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Get the function with the specified name. This function can be a temporary function or a - * function in the current database. This throws an `AnalysisException` when no `Function` - * can be found. + * Gets the function with the specified name. This function can be a temporary function or a + * function. This throws an `AnalysisException` when no `Function` can be found. 
*/ override def getFunction(functionName: String): Function = { - getFunction(null, functionName) + val functionIdent = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName) + getFunction(functionIdent.database.orNull, functionIdent.funcName) } /** - * Get the function with the specified name. This returns `None` when no `Function` can be + * Gets the function with the specified name. This returns `None` when no `Function` can be * found. */ override def getFunction(dbName: String, functionName: String): Function = { @@ -219,22 +218,23 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Check if the database with the specified name exists. + * Checks if the database with the specified name exists. */ override def databaseExists(dbName: String): Boolean = { sessionCatalog.databaseExists(dbName) } /** - * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view in the current database. + * Checks if the table or view with the specified name exists. This can either be a temporary + * view or a table/view. */ override def tableExists(tableName: String): Boolean = { - tableExists(null, tableName) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + tableExists(tableIdent.database.orNull, tableIdent.table) } /** - * Check if the table or view with the specified name exists in the specified database. + * Checks if the table or view with the specified name exists in the specified database. */ override def tableExists(dbName: String, tableName: String): Boolean = { val tableIdent = TableIdentifier(tableName, Option(dbName)) @@ -242,15 +242,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Check if the function with the specified name exists. This can either be a temporary function - * or a function in the current database. + * Checks if the function with the specified name exists. This can either be a temporary function + * or a function. */ override def functionExists(functionName: String): Boolean = { - functionExists(null, functionName) + val functionIdent = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName) + functionExists(functionIdent.database.orNull, functionIdent.funcName) } /** - * Check if the function with the specified name exists in the specified database. + * Checks if the function with the specified name exists in the specified database. */ override def functionExists(dbName: String, functionName: String): Boolean = { sessionCatalog.functionExists(FunctionIdentifier(functionName, Option(dbName))) @@ -303,7 +304,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: * (Scala-specific) - * Create a table from the given path based on a data source, a schema and a set of options. + * Creates a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops @@ -338,7 +339,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * Drops the local temporary view with the given view name in the catalog. * If the view has been cached/persisted before, it's also unpersisted. * - * @param viewName the name of the view to be dropped. + * @param viewName the identifier of the temporary view to be dropped. 
* @group ddl_ops * @since 2.0.0 */ @@ -353,7 +354,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * Drops the global temporary view with the given view name in the catalog. * If the view has been cached/persisted before, it's also unpersisted. * - * @param viewName the name of the view to be dropped. + * @param viewName the identifier of the global temporary view to be dropped. * @group ddl_ops * @since 2.1.0 */ @@ -365,9 +366,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Recover all the partitions in the directory of a table and update the catalog. + * Recovers all the partitions in the directory of a table and update the catalog. * - * @param tableName the name of the table to be repaired. + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in the + * current database. * @group ddl_ops * @since 2.1.1 */ @@ -378,7 +381,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns true if the table is currently cached in-memory. + * Returns true if the table or view is currently cached in-memory. * * @group cachemgmt * @since 2.0.0 @@ -388,7 +391,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Caches the specified table in-memory. + * Caches the specified table or view in-memory. * * @group cachemgmt * @since 2.0.0 @@ -398,7 +401,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Removes the specified table from the in-memory cache. + * Removes the specified table or view from the in-memory cache. * * @group cachemgmt * @since 2.0.0 @@ -408,7 +411,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Removes all cached tables from the in-memory cache. + * Removes all cached tables or views from the in-memory cache. * * @group cachemgmt * @since 2.0.0 @@ -428,7 +431,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Refresh the cache entry for a table, if any. For Hive metastore table, the metadata + * Refreshes the cache entry for a table or view, if any. For Hive metastore table, the metadata * is refreshed. For data source tables, the schema will not be inferred and refreshed. * * @group cachemgmt @@ -452,7 +455,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Refresh the cache entry and the associated metadata for all dataframes (if any), that contain + * Refreshes the cache entry and the associated metadata for all Dataset (if any), that contain * the given data source path. 
* * @group cachemgmt diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 9742b3b2d5c2..6469e501c1f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -102,6 +102,11 @@ class CatalogSuite assert(col.isPartition == tableMetadata.partitionColumnNames.contains(col.name)) assert(col.isBucket == bucketColumnNames.contains(col.name)) } + + dbName.foreach { db => + val expected = columns.collect().map(_.name).toSet + assert(spark.catalog.listColumns(s"$db.$tableName").collect().map(_.name).toSet == expected) + } } override def afterEach(): Unit = { @@ -345,6 +350,7 @@ class CatalogSuite // Find a qualified table assert(spark.catalog.getTable(db, "tbl_y").name === "tbl_y") + assert(spark.catalog.getTable(s"$db.tbl_y").name === "tbl_y") // Find an unqualified table using the current database intercept[AnalysisException](spark.catalog.getTable("tbl_y")) @@ -378,6 +384,11 @@ class CatalogSuite assert(fn2.database === db) assert(!fn2.isTemporary) + val fn2WithQualifiedName = spark.catalog.getFunction(s"$db.fn2") + assert(fn2WithQualifiedName.name === "fn2") + assert(fn2WithQualifiedName.database === db) + assert(!fn2WithQualifiedName.isTemporary) + // Find an unqualified function using the current database intercept[AnalysisException](spark.catalog.getFunction("fn2")) spark.catalog.setCurrentDatabase(db) @@ -403,6 +414,7 @@ class CatalogSuite assert(!spark.catalog.tableExists("tbl_x")) assert(!spark.catalog.tableExists("tbl_y")) assert(!spark.catalog.tableExists(db, "tbl_y")) + assert(!spark.catalog.tableExists(s"$db.tbl_y")) // Create objects. createTempTable("tbl_x") @@ -413,11 +425,15 @@ class CatalogSuite // Find a qualified table assert(spark.catalog.tableExists(db, "tbl_y")) + assert(spark.catalog.tableExists(s"$db.tbl_y")) // Find an unqualified table using the current database assert(!spark.catalog.tableExists("tbl_y")) spark.catalog.setCurrentDatabase(db) assert(spark.catalog.tableExists("tbl_y")) + + // Unable to find the table, although the temp view with the given name exists + assert(!spark.catalog.tableExists(db, "tbl_x")) } } } @@ -429,6 +445,7 @@ class CatalogSuite assert(!spark.catalog.functionExists("fn1")) assert(!spark.catalog.functionExists("fn2")) assert(!spark.catalog.functionExists(db, "fn2")) + assert(!spark.catalog.functionExists(s"$db.fn2")) // Create objects. createTempFunction("fn1") @@ -439,11 +456,15 @@ class CatalogSuite // Find a qualified function assert(spark.catalog.functionExists(db, "fn2")) + assert(spark.catalog.functionExists(s"$db.fn2")) // Find an unqualified function using the current database assert(!spark.catalog.functionExists("fn2")) spark.catalog.setCurrentDatabase(db) assert(spark.catalog.functionExists("fn2")) + + // Unable to find the function, although the temp function with the given name exists + assert(!spark.catalog.functionExists(db, "fn1")) } } } From 11238d4c62961c03376d9b2899221ec74313363a Mon Sep 17 00:00:00 2001 From: Anirudh Ramanathan Date: Tue, 4 Apr 2017 10:46:44 -0700 Subject: [PATCH 204/512] [SPARK-18278][SCHEDULER] Documentation to point to Kubernetes cluster scheduler ## What changes were proposed in this pull request? 
Adding documentation to point to Kubernetes cluster scheduler being developed out-of-repo in https://github.com/apache-spark-on-k8s/spark cc rxin srowen tnachen ash211 mccheah erikerlandson ## How was this patch tested? Docs only change Author: Anirudh Ramanathan Author: foxish Closes #17522 from foxish/upstream-doc. --- docs/cluster-overview.md | 6 +++++- docs/index.md | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 814e4406cf43..a2ad958959a5 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,7 +52,11 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. - +* [Kubernetes (experimental)](https://github.com/apache-spark-on-k8s/spark) -- In addition to the above, +there is experimental support for Kubernetes. Kubernetes is an open-source platform +for providing container-centric infrastructure. Kubernetes support is being actively +developed in an [apache-spark-on-k8s](https://github.com/apache-spark-on-k8s/) Github organization. +For documentation, refer to that project's README. # Submitting Applications diff --git a/docs/index.md b/docs/index.md index 19a9d3bfc601..ad4f24ff1a5d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -115,6 +115,7 @@ options for deployment: * [Mesos](running-on-mesos.html): deploy a private cluster using [Apache Mesos](http://mesos.apache.org) * [YARN](running-on-yarn.html): deploy Spark on top of Hadoop NextGen (YARN) + * [Kubernetes (experimental)](https://github.com/apache-spark-on-k8s/spark): deploy Spark on top of Kubernetes **Other Documents:** From 0736980f395f114faccbd58e78280ca63ed289c7 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 4 Apr 2017 11:38:05 -0700 Subject: [PATCH 205/512] [SPARK-20191][YARN] Create wrapper for RackResolver so tests can override it. Current test code tries to override the RackResolver used by setting configuration params, but because YARN libs statically initialize the resolver the first time it's used, that means that those configs don't really take effect during Spark tests. This change adds a wrapper class that easily allows tests to override the behavior of the resolver for the Spark code that uses it. Author: Marcelo Vanzin Closes #17508 from vanzin/SPARK-20191.
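As a minimal sketch of the approach described above (not something this patch itself adds), a test can simply subclass the new `SparkRackResolver` wrapper and return a canned rack, instead of trying to configure YARN's statically initialized `RackResolver`. The class name `FixedRackResolver` below is hypothetical; the pattern mirrors the `MockResolver` this patch adds to `YarnAllocatorSuite`.

```scala
package org.apache.spark.deploy.yarn

import org.apache.hadoop.conf.Configuration

// Hypothetical test-only resolver: rather than setting YARN topology configs
// (which the statically initialized RackResolver ignores after first use),
// the test overrides the wrapper's resolve() and returns a fixed rack.
class FixedRackResolver extends SparkRackResolver {
  override def resolve(conf: Configuration, hostName: String): String = "/test-rack"
}
```

An instance of such a resolver could then be passed wherever a `SparkRackResolver` is expected, e.g. to the `YarnAllocator` and `LocalityPreferredContainerPlacementStrategy` constructors changed in the diff below.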
--- ...yPreferredContainerPlacementStrategy.scala | 6 +-- .../spark/deploy/yarn/SparkRackResolver.scala | 40 +++++++++++++++++++ .../spark/deploy/yarn/YarnAllocator.scala | 13 ++---- .../spark/deploy/yarn/YarnRMClient.scala | 2 +- .../yarn/LocalityPlacementStrategySuite.scala | 8 +--- .../deploy/yarn/YarnAllocatorSuite.scala | 22 +++------- 6 files changed, 56 insertions(+), 35 deletions(-) create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala index f2b6324db619..257dc83621e9 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.hadoop.yarn.util.RackResolver import org.apache.spark.SparkConf import org.apache.spark.internal.config._ @@ -83,7 +82,8 @@ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], rack private[yarn] class LocalityPreferredContainerPlacementStrategy( val sparkConf: SparkConf, val yarnConf: Configuration, - val resource: Resource) { + val resource: Resource, + resolver: SparkRackResolver) { /** * Calculate each container's node locality and rack locality @@ -139,7 +139,7 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( // still be allocated with new container request. val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray val racks = hosts.map { h => - RackResolver.resolve(yarnConf, h).getNetworkLocation + resolver.resolve(yarnConf, h) }.toSet containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala new file mode 100644 index 000000000000..c711d088f211 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.util.RackResolver +import org.apache.log4j.{Level, Logger} + +/** + * Wrapper around YARN's [[RackResolver]]. This allows Spark tests to easily override the + * default behavior, since YARN's class self-initializes the first time it's called, and + * future calls all use the initial configuration. + */ +private[yarn] class SparkRackResolver { + + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. + if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } + + def resolve(conf: Configuration, hostName: String): String = { + RackResolver.resolve(conf, hostName).getNetworkLocation() + } + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 25556763da90..ed77a6e4a1c7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} import org.apache.spark.{SecurityManager, SparkConf, SparkException} @@ -65,16 +64,12 @@ private[yarn] class YarnAllocator( amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, - localResources: Map[String, LocalResource]) + localResources: Map[String, LocalResource], + resolver: SparkRackResolver) extends Logging { import YarnAllocator._ - // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. - if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { - Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) - } - // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] val allocatedContainerToHostMap = new HashMap[ContainerId, String] @@ -159,7 +154,7 @@ private[yarn] class YarnAllocator( // A container placement strategy based on pending tasks' locality preference private[yarn] val containerPlacementStrategy = - new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource) + new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) /** * Use a different clock for YarnAllocator. This is mainly used for testing. 
@@ -424,7 +419,7 @@ private[yarn] class YarnAllocator( // Match remaining by rack val remainingAfterRackMatches = new ArrayBuffer[Container] for (allocatedContainer <- remainingAfterHostMatches) { - val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation + val rack = resolver.resolve(conf, allocatedContainer.getNodeId.getHost) matchContainerToRequest(allocatedContainer, rack, containersToUse, remainingAfterRackMatches) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 53fb467f6408..72f4d273ab53 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -75,7 +75,7 @@ private[spark] class YarnRMClient extends Logging { registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, - localResources) + localResources, new SparkRackResolver()) } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala index fb80ff9f3132..b7f25656e49a 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.yarn +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet, Set} -import org.apache.hadoop.fs.CommonConfigurationKeysPublic -import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Mockito._ @@ -51,9 +50,6 @@ class LocalityPlacementStrategySuite extends SparkFunSuite { private def runTest(): Unit = { val yarnConf = new YarnConfiguration() - yarnConf.setClass( - CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, - classOf[MockResolver], classOf[DNSToSwitchMapping]) // The numbers below have been chosen to balance being large enough to replicate the // original issue while not taking too long to run when the issue is fixed. 
The main @@ -62,7 +58,7 @@ class LocalityPlacementStrategySuite extends SparkFunSuite { val resource = Resource.newInstance(8 * 1024, 4) val strategy = new LocalityPreferredContainerPlacementStrategy(new SparkConf(), - yarnConf, resource) + yarnConf, resource, new MockResolver()) val totalTasks = 32 * 1024 val totalContainers = totalTasks / 16 diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index fcc0594cf6d8..97b0e8aca333 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -17,12 +17,9 @@ package org.apache.spark.deploy.yarn -import java.util.{Arrays, List => JList} - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.CommonConfigurationKeysPublic -import org.apache.hadoop.net.DNSToSwitchMapping +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest @@ -38,24 +35,16 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.ManualClock -class MockResolver extends DNSToSwitchMapping { +class MockResolver extends SparkRackResolver { - override def resolve(names: JList[String]): JList[String] = { - if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2") - else Arrays.asList("/rack1") + override def resolve(conf: Configuration, hostName: String): String = { - if (hostName == "host3") "/rack2" else "/rack1" } - override def reloadCachedMappings() {} - - def reloadCachedMappings(names: JList[String]) {} } class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new YarnConfiguration() - conf.setClass( - CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, - classOf[MockResolver], classOf[DNSToSwitchMapping]) - val sparkConf = new SparkConf() sparkConf.set("spark.driver.host", "localhost") sparkConf.set("spark.driver.port", "4040") @@ -111,7 +100,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient, appAttemptId, new SecurityManager(sparkConf), - Map()) + Map(), + new MockResolver()) } def createContainer(host: String): Container = { From 0e2ee8204415d28613a60593f2b6e2b3d4ef794f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 4 Apr 2017 11:42:14 -0700 Subject: [PATCH 206/512] [MINOR][R] Reorder `Collate` fields in DESCRIPTION file ## What changes were proposed in this pull request? It seems the CRAN check script corrects `R/pkg/DESCRIPTION` and enforces the order of the `Collate` fields. This PR proposes to fix `catalog.R`'s position so that running this script does not produce a small diff in this file every time. ## How was this patch tested? Manually via `./R/check-cran.sh`. Author: hyukjinkwon Closes #17528 from HyukjinKwon/minor-reorder-description.
--- R/pkg/DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index f475ee87702e..879c1f80f2c5 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -32,10 +32,10 @@ Collate: 'pairRDD.R' 'DataFrame.R' 'SQLContext.R' - 'catalog.R' 'WindowSpec.R' 'backend.R' 'broadcast.R' + 'catalog.R' 'client.R' 'context.R' 'deserialize.R' From 402bf2a50ddd4039ff9f376b641bd18fffa54171 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 4 Apr 2017 11:56:21 -0700 Subject: [PATCH 207/512] [SPARK-20204][SQL] remove SimpleCatalystConf and CatalystConf type alias ## What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/17285 . ## How was this patch tested? existing tests Author: Wenchen Fan Closes #17521 from cloud-fan/conf. --- .../sql/catalyst/SimpleCatalystConf.scala | 51 ------------------- .../sql/catalyst/analysis/Analyzer.scala | 21 ++++---- .../sql/catalyst/analysis/ResolveHints.scala | 4 +- .../analysis/ResolveInlineTables.scala | 5 +- .../SubstituteUnresolvedOrdinals.scala | 4 +- .../spark/sql/catalyst/analysis/view.scala | 4 +- .../sql/catalyst/catalog/SessionCatalog.scala | 7 +-- .../sql/catalyst/catalog/interface.scala | 5 +- .../sql/catalyst/optimizer/Optimizer.scala | 18 +++---- .../sql/catalyst/optimizer/expressions.scala | 8 +-- .../spark/sql/catalyst/optimizer/joins.scala | 3 +- .../apache/spark/sql/catalyst/package.scala | 8 --- .../plans/logical/LocalRelation.scala | 5 +- .../catalyst/plans/logical/LogicalPlan.scala | 8 +-- .../plans/logical/basicLogicalOperators.scala | 32 ++++++------ .../statsEstimation/AggregateEstimation.scala | 4 +- .../statsEstimation/EstimationUtils.scala | 4 +- .../statsEstimation/FilterEstimation.scala | 4 +- .../statsEstimation/JoinEstimation.scala | 8 +-- .../statsEstimation/ProjectEstimation.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../sql/catalyst/analysis/AnalysisTest.scala | 4 +- .../analysis/DecimalPrecisionSuite.scala | 1 - .../SubstituteUnresolvedOrdinalsSuite.scala | 6 +-- .../catalog/SessionCatalogSuite.scala | 5 +- .../optimizer/AggregateOptimizeSuite.scala | 5 +- .../BinaryComparisonSimplificationSuite.scala | 2 - .../BooleanSimplificationSuite.scala | 5 +- .../optimizer/CombiningLimitsSuite.scala | 3 +- .../optimizer/ConstantFoldingSuite.scala | 3 +- .../optimizer/DecimalAggregatesSuite.scala | 3 +- .../optimizer/EliminateSortsSuite.scala | 5 +- .../InferFiltersFromConstraintsSuite.scala | 7 ++- .../optimizer/JoinOptimizationSuite.scala | 3 +- .../catalyst/optimizer/JoinReorderSuite.scala | 7 +-- .../optimizer/LimitPushdownSuite.scala | 1 - .../optimizer/OptimizeCodegenSuite.scala | 3 +- .../catalyst/optimizer/OptimizeInSuite.scala | 11 ++-- .../optimizer/OuterJoinEliminationSuite.scala | 7 ++- .../PropagateEmptyRelationSuite.scala | 5 +- .../optimizer/PruneFiltersSuite.scala | 7 ++- .../RewriteDistinctAggregatesSuite.scala | 9 ++-- .../optimizer/SetOperationSuite.scala | 3 +- .../optimizer/StarJoinReorderSuite.scala | 7 ++- .../spark/sql/catalyst/plans/PlanTest.scala | 4 +- .../AggregateEstimationSuite.scala | 5 +- .../BasicStatsEstimationSuite.scala | 8 +-- .../StatsEstimationTestBase.scala | 7 +-- .../spark/sql/execution/ExistingRDD.scala | 7 +-- .../execution/columnar/InMemoryRelation.scala | 5 +- .../datasources/DataSourceStrategy.scala | 8 ++- .../datasources/LogicalRelation.scala | 4 +- .../sql/execution/streaming/memory.scala | 4 +- .../sql/sources/DataSourceAnalysisSuite.scala | 4 +- 54 files 
changed, 164 insertions(+), 220 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala deleted file mode 100644 index 8498cf1c9be7..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SimpleCatalystConf.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import java.util.TimeZone - -import org.apache.spark.sql.internal.SQLConf - - -/** - * A SQLConf that can be used for local testing. This class is only here to minimize the change - * for ticket SPARK-19944 (moves SQLConf from sql/core to sql/catalyst). This class should - * eventually be removed (test cases should just create SQLConf and set values appropriately). - */ -case class SimpleCatalystConf( - override val caseSensitiveAnalysis: Boolean, - override val orderByOrdinal: Boolean = true, - override val groupByOrdinal: Boolean = true, - override val optimizerMaxIterations: Int = 100, - override val optimizerInSetConversionThreshold: Int = 10, - override val maxCaseBranchesForCodegen: Int = 20, - override val tableRelationCacheSize: Int = 1000, - override val runSQLonFile: Boolean = true, - override val crossJoinEnabled: Boolean = false, - override val cboEnabled: Boolean = false, - override val joinReorderEnabled: Boolean = false, - override val joinReorderDPThreshold: Int = 12, - override val starSchemaDetection: Boolean = false, - override val warehousePath: String = "/user/hive/warehouse", - override val sessionLocalTimeZone: String = TimeZone.getDefault().getID, - override val maxNestedViewDepth: Int = 100, - override val constraintPropagationEnabled: Boolean = true) - extends SQLConf { - - override def clone(): SimpleCatalystConf = this.copy() -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1b3a53c6359e..2d53d2424a34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -42,13 +43,13 @@ import org.apache.spark.sql.types._ * to resolve attribute references. 
*/ object SimpleAnalyzer extends Analyzer( - new SessionCatalog( - new InMemoryCatalog, - EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) { - override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} - }, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SessionCatalog( + new InMemoryCatalog, + EmptyFunctionRegistry, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) { + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} + }, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns @@ -89,11 +90,11 @@ object AnalysisContext { */ class Analyzer( catalog: SessionCatalog, - conf: CatalystConf, + conf: SQLConf, maxIterations: Int) extends RuleExecutor[LogicalPlan] with CheckAnalysis { - def this(catalog: SessionCatalog, conf: CatalystConf) = { + def this(catalog: SessionCatalog, conf: SQLConf) = { this(catalog, conf, conf.optimizerMaxIterations) } @@ -2331,7 +2332,7 @@ class Analyzer( } /** - * Replace [[TimeZoneAwareExpression]] without [[TimeZone]] by its copy with session local + * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local * time zone. */ object ResolveTimeZone extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 920033a9a848..f8004ca300ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.internal.SQLConf /** @@ -43,7 +43,7 @@ object ResolveHints { * * This rule must happen before common table expressions. */ - class ResolveBroadcastHints(conf: CatalystConf) extends Rule[LogicalPlan] { + class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] { private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") def resolver: Resolver = conf.resolver diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index d5b3ea8c37c6..a991dd96e282 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -19,16 +19,17 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. 
*/ -case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index 38a3d3de1288..256b18771052 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType /** * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression. */ -class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] { +class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { private def isIntLiteral(e: Expression) = e match { case Literal(_, IntegerType) => true case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index a5640a6c967a..3bd54c257d98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf /** * This file defines analysis rules related to views. @@ -47,7 +47,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * This should be only done after the batch of Resolution, because the view attributes are not * completely resolved during the batch of Resolution. 
*/ -case class AliasViewChild(conf: CatalystConf) extends Rule[LogicalPlan] { +case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 72ab07540889..6f8c6ee2f0f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} object SessionCatalog { @@ -52,7 +53,7 @@ class SessionCatalog( val externalCatalog: ExternalCatalog, globalTempViewManager: GlobalTempViewManager, functionRegistry: FunctionRegistry, - conf: CatalystConf, + conf: SQLConf, hadoopConf: Configuration, parser: ParserInterface, functionResourceLoader: FunctionResourceLoader) extends Logging { @@ -63,7 +64,7 @@ class SessionCatalog( def this( externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry, - conf: CatalystConf) { + conf: SQLConf) { this( externalCatalog, new GlobalTempViewManager("global_temp"), @@ -79,7 +80,7 @@ class SessionCatalog( this( externalCatalog, new SimpleFunctionRegistry, - SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } /** List of temporary tables, mapping from table name to their logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 3f25f9e7258f..dc2e40424fd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -25,12 +25,13 @@ import scala.collection.mutable import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -425,7 +426,7 @@ case class CatalogRelation( /** Only compare table identifier. 
*/ override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for // hive serde tables, we will always generate a statistics. // TODO: unify the table stats generation. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index dbf479d21513..577112779eea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -28,13 +27,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * Abstract class all optimizers should inherit of, contains the standard batches (extending * Optimizers can override this. */ -abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) +abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) extends RuleExecutor[LogicalPlan] { protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) @@ -160,8 +160,8 @@ class SimpleTestOptimizer extends Optimizer( new SessionCatalog( new InMemoryCatalog, EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)), - new SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)), + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) /** * Remove redundant aliases from a query plan. A redundant alias is an alias that does not change @@ -270,7 +270,7 @@ object RemoveRedundantProject extends Rule[LogicalPlan] { /** * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. */ -case class LimitPushDown(conf: CatalystConf) extends Rule[LogicalPlan] { +case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] { private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { plan match { @@ -617,7 +617,7 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. */ -case class InferFiltersFromConstraints(conf: CatalystConf) +case class InferFiltersFromConstraints(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) { inferFilters(plan) @@ -715,7 +715,7 @@ object EliminateSorts extends Rule[LogicalPlan] { * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`. * 3) by eliminating the always-true conditions given the constraints on the child's output. 
*/ -case class PruneFilters(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { +case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child @@ -1057,7 +1057,7 @@ object CombineLimits extends Rule[LogicalPlan] { * the join between R and S is not a cartesian product and therefore should be allowed. * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule. */ -case class CheckCartesianProducts(conf: CatalystConf) +case class CheckCartesianProducts(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Check if a join is a cartesian product. Returns true if @@ -1092,7 +1092,7 @@ case class CheckCartesianProducts(conf: CatalystConf) * This uses the same rules for increasing the precision and scale of the output as * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ -case class DecimalAggregates(conf: CatalystConf) extends Rule[LogicalPlan] { +case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 33039127f16c..8445ee06bd89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -27,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /* @@ -115,7 +115,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 2. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ -case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] { +case class OptimizeIn(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case expr @ In(v, list) if expr.inSetConvertible => @@ -346,7 +346,7 @@ object LikeSimplification extends Rule[LogicalPlan] { * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. */ -case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { +case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { private def isNullLiteral(e: Expression): Boolean = e match { case Literal(null, _) => true case _ => false @@ -482,7 +482,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { /** * Optimizes expressions by replacing according to CodeGen configuration. 
*/ -case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { +case class OptimizeCodegen(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: CaseWhen if canCodegen(e) => e.toCodegen() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 5f7316566b3b..250dd07a16eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation} import org.apache.spark.sql.catalyst.plans._ @@ -440,7 +439,7 @@ case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHe * * This rule should be executed before pushing down the Filter */ -case class EliminateOuterJoin(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { +case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Returns whether the expression returns null or false when all inputs are nulls. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index 4af56afebb76..f9c88d496e89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.internal.SQLConf - /** * Catalyst is a library for manipulating relational query plans. All classes in catalyst are * considered an internal API to Spark SQL and are subject to change between minor releases. @@ -30,10 +28,4 @@ package object catalyst { * 2.10.* builds. See SI-6240 for more details. */ protected[sql] object ScalaReflectionLock - - /** - * This class is only here to minimize the change for ticket SPARK-19944 - * (moves SQLConf from sql/core to sql/catalyst). This class should eventually be removed. 
- */ - type CatalystConf = SQLConf } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 1faabcfcb73b..b7177c4a2c4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -74,7 +75,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index f71a976bd7a2..036b6256684c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * first time. If the configuration changes, the cache can be invalidated by calling * [[invalidateStatsCache()]]. */ - final def stats(conf: CatalystConf): Statistics = statsCache.getOrElse { + final def stats(conf: SQLConf): Statistics = statsCache.getOrElse { statsCache = Some(computeStats(conf)) statsCache.get } @@ -108,7 +108,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * * [[LeafNode]]s must override this. */ - protected def computeStats(conf: CatalystConf): Statistics = { + protected def computeStats(conf: SQLConf): Statistics = { if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } @@ -335,7 +335,7 @@ abstract class UnaryNode extends LogicalPlan { override protected def validConstraints: Set[Expression] = child.constraints - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. 
val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 19db42c80895..c91de08ca5ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.{CatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -64,7 +64,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (conf.cboEnabled) { ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf)) } else { @@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: LogicalPlan) child.constraints.union(predicates.toSet) } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (conf.cboEnabled) { FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) } else { @@ -191,7 +191,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val leftSize = left.stats(conf).sizeInBytes val rightSize = right.stats(conf).sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize @@ -208,7 +208,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override protected def validConstraints: Set[Expression] = leftConstraints - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { left.stats(conf).copy() } } @@ -247,7 +247,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } @@ -356,7 +356,7 @@ case class Join( case _ => resolvedExceptNatural } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { def simpleEstimation: Statistics = joinType match { case LeftAnti | LeftSemi => // LeftSemi and LeftAnti won't ever be bigger than left @@ -382,7 +382,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def 
output: Seq[Attribute] = child.output // set isBroadcastable to true so the child will be broadcasted - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = child.stats(conf).copy(isBroadcastable = true) } @@ -538,7 +538,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) } @@ -571,7 +571,7 @@ case class Aggregate( child.constraints.union(getAliasedConstraints(nonAgg)) } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { def simpleEstimation: Statistics = { if (groupingExpressions.isEmpty) { Statistics( @@ -687,7 +687,7 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } @@ -758,7 +758,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val childStats = child.stats(conf) val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) @@ -778,7 +778,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val childStats = child.stats(conf) if (limit == 0) { @@ -827,7 +827,7 @@ case class Sample( override def output: Seq[Attribute] = child.output - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val ratio = upperBound - lowerBound val childStats = child.stats(conf) var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) @@ -893,7 +893,7 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1) + override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1) } /** A logical plan for `dropDuplicates`. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index ce74554c1701..48b5fbb03ef1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} +import org.apache.spark.sql.internal.SQLConf object AggregateEstimation { @@ -29,7 +29,7 @@ object AggregateEstimation { * Estimate the number of output rows based on column stats of group-by columns, and propagate * column stats for aggregate expressions. */ - def estimate(conf: CatalystConf, agg: Aggregate): Option[Statistics] = { + def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = { val childStats = agg.child.stats(conf) // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 4d18b28be866..5577233ffa6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.math.BigDecimal.RoundingMode -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StringType} object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ - def rowCountsExist(conf: CatalystConf, plans: LogicalPlan*): Boolean = + def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean = plans.forall(_.stats(conf).rowCount.isDefined) /** Check if each attribute has column stat in the corresponding statistics. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 03c76cd41d81..7bd8e6511232 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -22,14 +22,14 @@ import scala.collection.mutable import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { +case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging { private val childStats = plan.child.stats(catalystConf) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 9782c0bb0a93..3245a73c8a2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -21,12 +21,12 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.internal.SQLConf object JoinEstimation extends Logging { @@ -34,7 +34,7 @@ object JoinEstimation extends Logging { * Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. 
*/ - def estimate(conf: CatalystConf, join: Join): Option[Statistics] = { + def estimate(conf: SQLConf, join: Join): Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => InnerOuterEstimation(conf, join).doEstimate() @@ -47,7 +47,7 @@ object JoinEstimation extends Logging { } } -case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging { +case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { private val leftStats = join.left.stats(conf) private val rightStats = join.right.stats(conf) @@ -288,7 +288,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } } -case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { +case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) { def doEstimate(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. We should do more diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index e9084ad8b859..d700cd3b20f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics} +import org.apache.spark.sql.internal.SQLConf object ProjectEstimation { import EstimationUtils._ - def estimate(conf: CatalystConf, project: Project): Option[Statistics] = { + def estimate(conf: SQLConf, project: Project): Option[Statistics] = { if (rowCountsExist(conf, project.child)) { val childStats = project.child.stats(conf) val inputAttrStats = childStats.attributeStats diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 06dc0b41204f..5b5d547f8fe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1151,4 +1151,13 @@ class SQLConf extends Serializable with Logging { } result } + + // For test only + private[spark] def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + val cloned = clone() + entries.foreach { + case (entry, value) => cloned.setConfString(entry.key, value.toString) + } + cloned + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 0f059b959146..1be25ec06c74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, 
SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf trait AnalysisTest extends PlanTest { @@ -29,7 +29,7 @@ trait AnalysisTest extends PlanTest { protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false) private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { - val conf = new SimpleCatalystConf(caseSensitive) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 6995faebfa86..8f43171f309a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala index 88f68ebadc72..2331346f325a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.internal.SQLConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { private lazy val a = testRelation2.output(0) @@ -44,7 +44,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { // order by ordinal can be turned off by config comparePlans( - new SubstituteUnresolvedOrdinals(conf.copy(orderByOrdinal = false)).apply(plan), + new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.ORDER_BY_ORDINAL -> false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) } @@ -60,7 +60,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { // group by ordinal can be turned off by config comparePlans( - new SubstituteUnresolvedOrdinals(conf.copy(groupByOrdinal = false)).apply(plan2), + new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.GROUP_BY_ORDINAL -> false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 56bca73a8857..9ba846fb2527 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias, View} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class InMemorySessionCatalogSuite extends SessionCatalogSuite { @@ -1382,7 +1383,7 @@ abstract class SessionCatalogSuite extends PlanTest { import org.apache.spark.sql.catalyst.dsl.plans._ Seq(true, false) foreach { caseSensitive => - val conf = SimpleCatalystConf(caseSensitive) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf) try { val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index b45bd977cbba..e6132ab2e4d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} class AggregateOptimizeSuite extends PlanTest { - override val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index 2bfddb7bc2f3..b29e1cbd1494 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import 
org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -30,7 +29,6 @@ import org.apache.spark.sql.catalyst.rules._ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 4d404f55aa57..935bff7cef2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,11 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) val batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: @@ -139,7 +138,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) } - private val caseInsensitiveConf = new SimpleCatalystConf(false) + private val caseInsensitiveConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> false) private val caseInsensitiveAnalyzer = new Analyzer( new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf), caseInsensitiveConf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 276b8055b08d..ac71887c16f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -33,7 +32,7 @@ class CombiningLimitsSuite extends PlanTest { Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, SimplifyConditionals) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index d9655bbcc2ce..25c592b9c1dd 
100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -34,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", Once, - OptimizeIn(SimpleCatalystConf(true)), + OptimizeIn(conf), ConstantFolding, BooleanSimplification) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index a491f4433370..cc4fb3a244a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +29,7 @@ class DecimalAggregatesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Decimal Optimizations", FixedPoint(100), - DecimalAggregates(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + DecimalAggregates(conf)) :: Nil } val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index c5f9cc185275..e318f36d7827 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class EliminateSortsSuite extends PlanTest { - override val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 98d8b897a916..c8fe37462726 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class InferFiltersFromConstraintsSuite extends PlanTest { @@ -32,7 +32,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true)), + InferFiltersFromConstraints(conf), CombineFilters) :: Nil } @@ -41,8 +41,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints(SimpleCatalystConf(caseSensitiveAnalysis = true, - constraintPropagationEnabled = false)), + InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), CombineFilters) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 61e81808147c..a43d78c7bd44 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.SimpleCatalystConf class JoinOptimizationSuite extends PlanTest { @@ -38,7 +37,7 @@ class JoinOptimizationSuite extends PlanTest { CombineFilters, PushDownPredicate, BooleanSimplification, - ReorderJoin(SimpleCatalystConf(true)), + ReorderJoin(conf), PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index d74008c1b302..1922eb30fdce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} @@ -25,12 +24,14 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import 
org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED, JOIN_REORDER_ENABLED} class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = SimpleCatalystConf( - caseSensitiveAnalysis = true, cboEnabled = true, joinReorderEnabled = true) + override val conf = new SQLConf().copy( + CASE_SENSITIVE -> true, CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 0f3ba6c89556..2885fd6841e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala index 4385b0e019f2..f3b65cc797ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._ class OptimizeCodegenSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil + val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(conf)) :: Nil } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 9daede1a5f95..d8937321ecb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, 
LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD import org.apache.spark.sql.types._ class OptimizeInSuite extends PlanTest { @@ -34,10 +34,10 @@ class OptimizeInSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", FixedPoint(10), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, - OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + OptimizeIn(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -159,12 +159,11 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) .analyze - val notOptimizedPlan = OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true))(plan) + val notOptimizedPlan = OptimizeIn(conf)(plan) comparePlans(notOptimizedPlan, plan) // Reduce the threshold to turning into InSet. - val optimizedPlan = OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true, - optimizerInSetConversionThreshold = 2))(plan) + val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan) optimizedPlan match { case Filter(cond, _) if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index cbabc1fa6d92..b7136703b754 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class OuterJoinEliminationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -32,7 +32,7 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true)), + EliminateOuterJoin(conf), PushPredicateThroughJoin) :: Nil } @@ -41,8 +41,7 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin(SimpleCatalystConf(caseSensitiveAnalysis = true, - constraintPropagationEnabled = false)), + EliminateOuterJoin(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), PushPredicateThroughJoin) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 
f771e3e9eba6..c261a6091d47 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans._ @@ -34,7 +33,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)), + PruneFilters(conf), PropagateEmptyRelation) :: Nil } @@ -46,7 +45,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + PruneFilters(conf)) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 20f7f69e86c0..741dd0cf428d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class PruneFiltersSuite extends PlanTest { @@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true)), + PruneFilters(conf), PushDownPredicate, PushPredicateThroughJoin) :: Nil } @@ -45,8 +45,7 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true, - constraintPropagationEnabled = false)), + PruneFilters(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), PushDownPredicate, PushPredicateThroughJoin) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 350a1c26fd1e..8cb939e010c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -16,19 +16,20 @@ */ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import 
org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{If, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { - override val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index ca4976f0d6db..756e0f35b217 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -35,7 +34,7 @@ class SetOperationSuite extends PlanTest { CombineUnions, PushProjectionThroughUnion, PushDownPredicate, - PruneFilters(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + PruneFilters(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 93fdd98d1ac9..003ce49eaf8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} @@ -25,12 +24,12 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} - +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, STARSCHEMA_DETECTION} class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf 
= SimpleCatalystConf( - caseSensitiveAnalysis = true, starSchemaDetection = true) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, STARSCHEMA_DETECTION -> true) object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index c73dfaf3f8fe..f44428c3512a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ abstract class PlanTest extends SparkFunSuite with PredicateHelper { - protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) + protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true) /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index c0b9515ca7cd..38483a298cef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.internal.SQLConf class AggregateEstimationSuite extends StatsEstimationTestBase { @@ -101,13 +102,13 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { val noGroupAgg = Aggregate(groupingExpressions = Nil, aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) - assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) == + assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == // overhead + count result size Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) val hasGroupAgg = Aggregate(groupingExpressions = attributes, aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) - assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) == + assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 0d92c1e35565..b06871f96f0d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.statsEstimation -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType @@ -116,10 +116,10 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { expectedStatsCboOff: Statistics): Unit = { // Invalidate statistics plan.invalidateStatsCache() - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn) plan.invalidateStatsCache() - assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) + assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff) } /** Check estimated stats when it's the same whether cbo is turned on or off. */ @@ -136,6 +136,6 @@ private case class DummyLogicalPlan( cboStats: Statistics) extends LogicalPlan { override def output: Seq[Attribute] = Nil override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = if (conf.cboEnabled) cboStats else defaultStats } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 9b2b8dbe1bf4..263f4e18803d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -18,16 +18,17 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED} import org.apache.spark.sql.types.{IntegerType, StringType} trait StatsEstimationTestBase extends SparkFunSuite { /** Enable stats estimation based on CBO. 
*/ - protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true) + protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true) def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes @@ -54,7 +55,7 @@ case class StatsTestPlan( attributeStats: AttributeMap[ColumnStat], size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList - override def computeStats(conf: CatalystConf): Statistics = Statistics( + override def computeStats(conf: SQLConf): Statistics = Statistics( // If sizeInBytes is useless in testing, we just use a fake value sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 49336f424822..2827b8ac0033 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Encoder, Row, SparkSession} -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -95,7 +96,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: CatalystConf): Statistics = Statistics( + @transient override def computeStats(conf: SQLConf): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -170,7 +171,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: CatalystConf): Statistics = Statistics( + @transient override def computeStats(conf: SQLConf): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. 
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 36037ac00372..0a9f3e799990 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -21,12 +21,13 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -69,7 +70,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c350d8bcbae9..e5c7c383d708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -21,12 +21,10 @@ import java.util.concurrent.Callable import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.fs.Path - import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} @@ -38,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPa import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -50,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. 
*/ -case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { def resolver: Resolver = conf.resolver diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 04a764bee2ef..3b14b794fd08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -16,11 +16,11 @@ */ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils @@ -73,7 +73,7 @@ case class LogicalRelation( // expId can be different but the relation is still the same. override lazy val cleanArgs: Seq[Any] = Seq(relation) - @transient override def computeStats(conf: CatalystConf): Statistics = { + @transient override def computeStats(conf: SQLConf): Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( Statistics(sizeInBytes = relation.sizeInBytes)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 6d34d51d31c1..971ce5afb177 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -25,11 +25,11 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -230,6 +230,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 448adcf11d65..b16c9f8fc96b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SimpleCatalystConf import 
org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal} import org.apache.spark.sql.execution.datasources.DataSourceAnalysis +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructType} class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -49,7 +49,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { } Seq(true, false).foreach { caseSensitive => - val rule = DataSourceAnalysis(SimpleCatalystConf(caseSensitive)) + val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)) test( s"convertStaticPartitions only handle INSERT having at least static partitions " + s"(caseSensitive: $caseSensitive)") { From 295747e59739ee8a697ac3eba485d3439e4a04c3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 4 Apr 2017 16:38:32 -0700 Subject: [PATCH 208/512] [SPARK-19716][SQL] support by-name resolution for struct type elements in array ## What changes were proposed in this pull request? Previously when we construct deserializer expression for array type, we will first cast the corresponding field to expected array type and then apply `MapObjects`. However, by doing that, we lose the opportunity to do by-name resolution for struct type inside array type. In this PR, I introduce a `UnresolvedMapObjects` to hold the lambda function and the input array expression. Then during analysis, after the input array expression is resolved, we get the actual array element type and apply by-name resolution. Then we don't need to add `Cast` for array type when constructing the deserializer expression, as the element type is determined later at analyzer. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17398 from cloud-fan/dataset. --- .../spark/sql/catalyst/ScalaReflection.scala | 66 +++++++++++-------- .../sql/catalyst/analysis/Analyzer.scala | 19 +++++- .../expressions/complexTypeExtractors.scala | 2 +- .../expressions/objects/objects.scala | 32 +++++++-- .../encoders/EncoderResolutionSuite.scala | 52 +++++++++++++++ .../sql/expressions/ReduceAggregator.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 9 +++ 7 files changed, 141 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index da37eb00dcd9..206ae2f0e5eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -178,15 +178,17 @@ object ScalaReflection extends ScalaReflection { * is [a: int, b: long], then we will hit runtime error and say that we can't construct class * `Data` with int and long, because we lost the information that `b` should be a string. 
* - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * don't need to cast struct type because there must be `UnresolvedExtractValue` or - * `GetStructField` wrapping it, thus we only need to handle leaf type. + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * only need to do this for leaf nodes. */ def upCastToExpectedType( expr: Expression, expected: DataType, walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr + case _: ArrayType => expr + // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and + // it's not trivial to support by-name resolution for StructType inside MapType. case _ => UpCast(expr, expected, walkedTypePath) } @@ -265,42 +267,48 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t + val Schema(_, elementNullable) = schemaFor(elementType) + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - // TODO: add runtime null check for primitive array - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None + val mapFunction: Expression => Expression = p => { + val converter = deserializerFor(elementType, Some(p), newTypePath) + if (elementNullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } } - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - val className = getClassNameFromType(elementType) - val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - Invoke( - MapObjects( - p => deserializerFor(elementType, Some(p), newTypePath), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) + val arrayData = UnresolvedMapObjects(mapFunction, getPath) + val arrayCls = arrayClassFor(elementType) + + if (elementNullable) { + Invoke(arrayData, "array", arrayCls) + } else { + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => "toIntArray" + case t if t <:< definitions.LongTpe => "toLongArray" + case t if t <:< definitions.DoubleTpe => "toDoubleArray" + case t if t <:< definitions.FloatTpe => "toFloatArray" + case t if t <:< definitions.ShortTpe => "toShortArray" + case t if t <:< definitions.ByteTpe => "toByteArray" + case t if t <:< definitions.BooleanTpe => "toBooleanArray" + case other => throw new IllegalStateException("expect primitive array element type " + + "but got " + other) + } + Invoke(arrayData, primitiveMethod, arrayCls) } case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) + val Schema(_, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val mapFunction: Expression => Expression = p => { val converter = deserializerFor(elementType, Some(p), newTypePath) - if 
(nullable) { + if (elementNullable) { converter } else { AssertNotNull(converter, newTypePath) @@ -312,7 +320,7 @@ object ScalaReflection extends ScalaReflection { case NoSymbol => classOf[Seq[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - MapObjects(mapFunction, getPath, dataType, Some(cls)) + UnresolvedMapObjects(mapFunction, getPath, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2d53d2424a34..c698ca6a8347 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ @@ -2227,8 +2227,21 @@ class Analyzer( validateTopLevelTupleFields(deserializer, inputs) val resolved = resolveExpression( deserializer, LocalRelation(inputs), throws = true) - validateNestedTupleFields(resolved) - resolved + val result = resolved transformDown { + case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => + inputData.dataType match { + case ArrayType(et, _) => + val expr = MapObjects(func, inputData, et, cls) transformUp { + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) + } + expr + case other => + throw new AnalysisException("need an array field but got " + other.simpleString) + } + } + validateNestedTupleFields(result) + result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index de1594d119e1..ef88cfb543eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -68,7 +68,7 @@ object ExtractValue { case StructType(_) => s"Field name should be String Literal, but it's $extraction" case other => - s"Can't extract value from $child" + s"Can't extract value from $child: need struct type but got ${other.simpleString}" } throw new AnalysisException(errorMsg) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bb584f7d087e..00e2ac91e67c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -448,6 +448,17 @@ object MapObjects { } } +case class UnresolvedMapObjects( + function: Expression => Expression, + child: Expression, + customCollectionCls: Option[Class[_]] = None) extends 
UnaryExpression with Unevaluable { + override lazy val resolved = false + + override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { + throw new UnsupportedOperationException("not resolved") + } +} + /** * Applies the given expression to every element of a collection of items, returning the result * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda @@ -581,17 +592,24 @@ case class MapObjects private( // collection val collObjectName = s"${cls.getName}$$.MODULE$$" val getBuilderVar = s"$collObjectName.newBuilder()" - - (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; - $builderValue.sizeHint($dataLength);""", + ( + s""" + ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; + $builderValue.sizeHint($dataLength); + """, genValue => s"$builderValue.$$plus$$eq($genValue);", - s"(${cls.getName}) $builderValue.result();") + s"(${cls.getName}) $builderValue.result();" + ) case None => // array - (s"""$convertedType[] $convertedArray = null; - $convertedArray = $arrayConstructor;""", + ( + s""" + $convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor; + """, genValue => s"$convertedArray[$loopIndex] = $genValue;", - s"new ${classOf[GenericArrayData].getName}($convertedArray);") + s"new ${classOf[GenericArrayData].getName}($convertedArray);" + ) } val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 802397d50e85..e5a3e1fd374d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -33,6 +33,10 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) +case class ArrayClass(arr: Seq[StringIntClass]) + +case class NestedArrayClass(nestedArr: Array[ArrayClass]) + class EncoderResolutionSuite extends PlanTest { private val str = UTF8String.fromString("hello") @@ -62,6 +66,54 @@ class EncoderResolutionSuite extends PlanTest { encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } + test("real type doesn't match encoder schema but they are compatible: array") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val array = new GenericArrayData(Array(InternalRow(1, 2, 3))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + } + + test("real type doesn't match encoder schema but they are compatible: nested array") { + val encoder = ExpressionEncoder[NestedArrayClass] + val et = new StructType().add("arr", ArrayType( + new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val attrs = Seq('nestedArr.array(et)) + val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3))) + val outerArr = new GenericArrayData(Array(InternalRow(innerArr))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr)) + } + + test("the real type is not compatible with encoder schema: non-array field") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.int) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "need an array field but got int") + } + + test("the real type is not compatible with encoder schema: array element 
type") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.array(new StructType().add("c", "int"))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "No such struct field a in c") + } + + test("the real type is not compatible with encoder schema: nested array element type") { + val encoder = ExpressionEncoder[NestedArrayClass] + + withClue("inner element is not array") { + val attrs = Seq('nestedArr.array(new StructType().add("arr", "int"))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "need an array field but got int") + } + + withClue("nested array element type is not compatible") { + val attrs = Seq('nestedArr.array(new StructType() + .add("arr", ArrayType(new StructType().add("c", "int"))))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "No such struct field a in c") + } + } + test("nullability of array type element should not fail analysis") { val encoder = ExpressionEncoder[Seq[Int]] val attrs = 'a.array(IntegerType) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala index 174378304d4a..e266ae55cc4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) extends Aggregator[T, (Boolean, T), T] { - private val encoder = implicitly[Encoder[T]] + @transient private val encoder = implicitly[Encoder[T]] override def zero: (Boolean, T) = (false, null.asInstanceOf[T]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 68e071a1a694..5b5cd28ad0c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -142,6 +142,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) } + test("as seq of case class - reorder fields by name") { + val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a")))) + val ds = df.as[Seq[ClassData]] + assert(ds.collect() === Array( + Seq(ClassData("a", 0)), + Seq(ClassData("a", 1)), + Seq(ClassData("a", 2)))) + } + test("map") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( From a59759e6c059617b2fc8102cbf41acc5d409b34a Mon Sep 17 00:00:00 2001 From: Seth Hendrickson Date: Tue, 4 Apr 2017 17:04:41 -0700 Subject: [PATCH 209/512] [SPARK-20183][ML] Added outlierRatio arg to MLTestingUtils.testOutliersWithSmallWeights ## What changes were proposed in this pull request? This is a small piece from https://github.com/apache/spark/pull/16722 which ultimately will add sample weights to decision trees. This is to allow more flexibility in testing outliers since linear models and trees behave differently. Note: The primary author when this is committed should be sethah since this is taken from his code. ## How was this patch tested? Existing tests Author: Joseph K. Bradley Closes #17501 from jkbradley/SPARK-20183. 
--- .../org/apache/spark/ml/classification/LinearSVCSuite.scala | 2 +- .../spark/ml/classification/LogisticRegressionSuite.scala | 2 +- .../org/apache/spark/ml/classification/NaiveBayesSuite.scala | 2 +- .../apache/spark/ml/regression/LinearRegressionSuite.scala | 3 ++- .../test/scala/org/apache/spark/ml/util/MLTestingUtils.scala | 5 +++-- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 4c63a2a88c6c..c763a4cef1af 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -164,7 +164,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC]( - dataset.as[LabeledPoint], estimator, 2, modelEquals) + dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals, 42L) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 1b6448037349..f0648d0936a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1874,7 +1874,7 @@ class LogisticRegressionSuite MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel, LogisticRegression]( dataset.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel, LogisticRegression]( - dataset.as[LabeledPoint], estimator, numClasses, modelEquals) + dataset.as[LabeledPoint], estimator, numClasses, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel, LogisticRegression]( dataset.as[LabeledPoint], estimator, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 4d5d299d1408..d41c5b533ded 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -178,7 +178,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes]( dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes]( - dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals) + dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes]( dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 6a51e75e12a3..c6a267b7283d 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -842,7 +842,8 @@ class LinearRegressionSuite MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression]( - datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals) + datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals, + outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index f1ed568d5e60..578f31c8e7db 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -260,12 +260,13 @@ object MLTestingUtils extends SparkFunSuite { data: Dataset[LabeledPoint], estimator: E with HasWeightCol, numClasses: Int, - modelEquals: (M, M) => Unit): Unit = { + modelEquals: (M, M) => Unit, + outlierRatio: Int): Unit = { import data.sqlContext.implicits._ val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap { case Instance(l, w, f) => val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1 - List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) + List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) } val trueModel = estimator.set(estimator.weightCol, "").fit(data) val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS) From b28bbffbadf7ebc4349666e8f17111f6fca18c9a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 4 Apr 2017 17:51:45 -0700 Subject: [PATCH 210/512] [SPARK-20003][ML] FPGrowthModel setMinConfidence should affect rules generation and transform ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-20003 I was doing some test and found the issue. ml.fpm.FPGrowthModel `setMinConfidence` should always affect rules generation and transform. Currently associationRules in FPGrowthModel is a lazy val and `setMinConfidence` in FPGrowthModel has no impact once associationRules got computed . I try to cache the associationRules to avoid re-computation if `minConfidence` is not changed, but this makes FPGrowthModel somehow stateful. Let me know if there's any concern. ## How was this patch tested? new unit test and I strength the unit test for model save/load to ensure the cache mechanism. Author: Yuhao Yang Closes #17336 from hhbyyh/fpmodelminconf. 
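The fix below keeps `associationRules` in sync with `setMinConfidence` by caching the last computed rules together with the confidence they were computed for. A minimal sketch of that recompute-only-on-change pattern, with a generic `compute` function standing in for the association-rule generation (class and method names here are illustrative, not Spark's API):

```scala
// Recompute a derived value only when its controlling parameter changes;
// Double.NaN compares unequal to everything, so the first call always computes.
class ConfidenceCachedRules[T](compute: Double => T) extends Serializable {
  @transient private var cachedMinConf: Double = Double.NaN
  @transient private var cachedRules: Option[T] = None

  def rulesFor(minConfidence: Double): T = synchronized {
    if (cachedRules.isEmpty || minConfidence != cachedMinConf) {
      cachedRules = Some(compute(minConfidence))  // recompute only on change
      cachedMinConf = minConfidence
    }
    cachedRules.get
  }
}
```

Repeated calls with the same confidence are served from the cache; any new value triggers exactly one recomputation, which mirrors the behaviour the new `FPGrowthSuite` test asserts.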
--- .../org/apache/spark/ml/fpm/FPGrowth.scala | 21 ++++++- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 56 +++++++++++++------ 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 65cc80619569..d604c1ac001a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -218,13 +218,28 @@ class FPGrowthModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) /** - * Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe + * Cache minConfidence and associationRules to avoid redundant computation for association rules + * during transform. The associationRules will only be re-computed when minConfidence changed. + */ + @transient private var _cachedMinConf: Double = Double.NaN + + @transient private var _cachedRules: DataFrame = _ + + /** + * Get association rules fitted using the minConfidence. Returns a dataframe * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and * "consequent" are Array[T] and "confidence" is Double. */ @Since("2.2.0") - @transient lazy val associationRules: DataFrame = { - AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + @transient def associationRules: DataFrame = { + if ($(minConfidence) == _cachedMinConf) { + _cachedRules + } else { + _cachedRules = AssociationRules + .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + _cachedMinConf = $(minConfidence) + _cachedRules + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 4603a618d2f9..6bec057511cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.fpm import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ @@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("items") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } + + test("FPGrowthModel setMinConfidence should affect rules generation and transform") { + val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset) + val oldRulesNum = model.associationRules.count() + val oldPredict = model.transform(dataset) + + model.setMinConfidence(0.8765) + assert(oldRulesNum > model.associationRules.count()) + assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + + // association rules should 
stay the same for same minConfidence + model.setMinConfidence(0.1) + assert(oldRulesNum === model.associationRules.count()) + assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + } + test("FPGrowth parameter check") { val fpGrowth = new FPGrowth().setMinSupport(0.4567) val model = fpGrowth.fit(dataset) .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) + MLTestingUtils.checkCopy(model) } test("read/write") { def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = { - assert(model.freqItemsets.sort("items").collect() === - model2.freqItemsets.sort("items").collect()) + assert(model.freqItemsets.collect().toSet.equals( + model2.freqItemsets.collect().toSet)) + assert(model.associationRules.collect().toSet.equals( + model2.associationRules.collect().toSet)) + assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals( + model2.setMinConfidence(0.9).associationRules.collect().toSet)) } val fPGrowth = new FPGrowth() testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, FPGrowthSuite.allParamSettings, checkModelData) } - - test("FPGrowth prediction should not contain duplicates") { - // This should generate rule 1 -> 3, 2 -> 3 - val dataset = spark.createDataFrame(Seq( - Array("1", "3"), - Array("2", "3") - ).map(Tuple1(_))).toDF("items") - val model = new FPGrowth().fit(dataset) - - val prediction = model.transform( - spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") - ).first().getAs[Seq[String]]("prediction") - - assert(prediction === Seq("3")) - } } object FPGrowthSuite { From c1b8b667506ed95c6c2808e7d3db8463435e73f6 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 4 Apr 2017 22:32:46 -0700 Subject: [PATCH 211/512] [SPARKR][DOC] update doc for fpgrowth ## What changes were proposed in this pull request? minor update zero323 Author: Felix Cheung Closes #17526 from felixcheung/rfpgrowthfollowup. --- R/pkg/R/mllib_clustering.R | 6 +----- R/pkg/R/mllib_fpm.R | 4 ++++ 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 0ebdb5a27308..97c9fa1b4584 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -498,11 +498,7 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @export #' @examples #' \dontrun{ -#' # nolint start -#' # An example "path/to/file" can be -#' # paste0(Sys.getenv("SPARK_HOME"), "/data/mllib/sample_lda_libsvm_data.txt") -#' # nolint end -#' text <- read.df("path/to/file", source = "libsvm") +#' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") #' model <- spark.lda(data = text, optimizer = "em") #' #' # get a summary of the model diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index 96251b2c7c19..dfcb45a1b66c 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -27,6 +27,10 @@ setClass("FPGrowthModel", slots = list(jobj = "jobj")) #' FP-growth #' #' A parallel FP-growth algorithm to mine frequent itemsets. +#' \code{spark.fpGrowth} fits a FP-growth model on a SparkDataFrame. Users can +#' \code{spark.freqItemsets} to get frequent itemsets, \code{spark.associationRules} to get +#' association rules, \code{predict} to make predictions on new data based on generated association +#' rules, and \code{write.ml}/\code{read.ml} to save/load fitted models. 
#' For more details, see #' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{ #' FP-growth}. From b6e71032d92a072b7c951e5ea641e9454b5e70ed Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 4 Apr 2017 22:46:42 -0700 Subject: [PATCH 212/512] Small doc fix for ReuseSubquery. --- .../main/scala/org/apache/spark/sql/execution/subquery.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 58be2d1da281..d11045fb6ac8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -150,7 +150,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { /** - * Find out duplicated exchanges in the spark plan, then use the same exchange for all the + * Find out duplicated subqueries in the spark plan, then use the same subquery result for all the * references. */ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { @@ -159,7 +159,7 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { if (!conf.exchangeReuseEnabled) { return plan } - // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. + // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]() plan transformAllExpressions { case sub: ExecSubqueryExpression => From dad499f324c6a93650aecfeb8cde10a405372930 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 4 Apr 2017 23:20:17 -0700 Subject: [PATCH 213/512] [SPARK-20209][SS] Execute next trigger immediately if previous batch took longer than trigger interval ## What changes were proposed in this pull request? For large trigger intervals (e.g. 10 minutes), if a batch takes 11 minutes, then it will wait for 9 mins before starting the next batch. This does not make sense. The processing time based trigger policy should be to do process batches as fast as possible, but no faster than 1 in every trigger interval. If batches are taking longer than trigger interval anyways, then no point waiting extra trigger interval. In this PR, I modified the ProcessingTimeExecutor to do so. Another minor change I did was to extract our StreamManualClock into a separate class so that it can be used outside subclasses of StreamTest. For example, ProcessingTimeExecutorSuite does not need to create any context for testing, just needs the StreamManualClock. ## How was this patch tested? Added new unit tests to comprehensively test this behavior. Author: Tathagata Das Closes #17525 from tdas/SPARK-20209. 
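The core of the change below is that the trigger executor now computes the next trigger time from the start of the current trigger, before running the batch, so a batch that overruns the interval is followed immediately by the next trigger instead of waiting another full interval. A simplified, self-contained sketch of that loop (using the system clock and a plain `Thread.sleep` rather than Spark's `Clock`/`waitTillTime`; not the actual `ProcessingTimeExecutor`):

```scala
object ProcessingTimeLoopSketch {
  /** Run `triggerHandler` at most once per `intervalMs`, without waiting if it overruns. */
  def runTriggers(intervalMs: Long)(triggerHandler: () => Boolean): Unit = {
    require(intervalMs >= 0)

    // Align `now` to the next multiple of the interval, as in the patched nextBatchTime.
    def nextBatchTime(now: Long): Long =
      if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs

    var continue = true
    while (continue) {
      val triggerTimeMs = System.currentTimeMillis()
      val nextTriggerTimeMs = nextBatchTime(triggerTimeMs)  // fixed before the batch runs
      continue = triggerHandler()
      val remainingMs = nextTriggerTimeMs - System.currentTimeMillis()
      if (continue && intervalMs > 0 && remainingMs > 0) {
        Thread.sleep(remainingMs)  // batch finished early: wait out the rest of the interval
      }
      // If remainingMs <= 0 the batch fell behind, so the next trigger starts immediately.
    }
  }
}
```

This is exactly the "as fast as possible, but no faster than one per interval" policy described in the commit message; the new `ProcessingTimeExecutorSuite` tests in the diff exercise both the early-finish and the falling-behind paths.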
--- .../spark/sql/kafka010/KafkaSourceSuite.scala | 1 + .../execution/streaming/TriggerExecutor.scala | 17 ++-- .../ProcessingTimeExecutorSuite.scala | 83 ++++++++++++++++-- .../sql/streaming/FileStreamSourceSuite.scala | 1 + .../FlatMapGroupsWithStateSuite.scala | 3 +- .../spark/sql/streaming/StreamSuite.scala | 1 + .../spark/sql/streaming/StreamTest.scala | 20 +---- .../streaming/StreamingAggregationSuite.scala | 1 + .../StreamingQueryListenerSuite.scala | 1 + .../sql/streaming/StreamingQuerySuite.scala | 87 +++++++++++-------- .../streaming/util/StreamManualClock.scala | 51 +++++++++++ 11 files changed, 194 insertions(+), 72 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 6391d6269c5a..0046ba7e43d1 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index 02996ac854f6..d188566f822b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -47,21 +47,22 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = extends TriggerExecutor with Logging { private val intervalMs = processingTime.intervalMs + require(intervalMs >= 0) - override def execute(batchRunner: () => Boolean): Unit = { + override def execute(triggerHandler: () => Boolean): Unit = { while (true) { - val batchStartTimeMs = clock.getTimeMillis() - val terminated = !batchRunner() + val triggerTimeMs = clock.getTimeMillis + val nextTriggerTimeMs = nextBatchTime(triggerTimeMs) + val terminated = !triggerHandler() if (intervalMs > 0) { - val batchEndTimeMs = clock.getTimeMillis() - val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs + val batchElapsedTimeMs = clock.getTimeMillis - triggerTimeMs if (batchElapsedTimeMs > intervalMs) { notifyBatchFallingBehind(batchElapsedTimeMs) } if (terminated) { return } - clock.waitTillTime(nextBatchTime(batchEndTimeMs)) + clock.waitTillTime(nextTriggerTimeMs) } else { if (terminated) { return @@ -70,7 +71,7 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = } } - /** Called when a batch falls behind. Expose for test only */ + /** Called when a batch falls behind */ def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { logWarning("Current batch is falling behind. 
The trigger interval is " + s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds") @@ -83,6 +84,6 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = * an interval of `100 ms`, `nextBatchTime(nextBatchTime(0)) = 200` rather than `0`). */ def nextBatchTime(now: Long): Long = { - now / intervalMs * intervalMs + intervalMs + if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 00d5e051de35..007554a83f54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -17,14 +17,24 @@ package org.apache.spark.sql.execution.streaming -import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import org.eclipse.jetty.util.ConcurrentHashSet +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{Clock, ManualClock, SystemClock} +import org.apache.spark.sql.streaming.util.StreamManualClock class ProcessingTimeExecutorSuite extends SparkFunSuite { + val timeout = 10.seconds + test("nextBatchTime") { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100)) assert(processingTimeExecutor.nextBatchTime(0) === 100) @@ -35,6 +45,57 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { assert(processingTimeExecutor.nextBatchTime(150) === 200) } + test("trigger timing") { + val triggerTimes = new ConcurrentHashSet[Int] + val clock = new StreamManualClock() + @volatile var continueExecuting = true + @volatile var clockIncrementInTrigger = 0L + val executor = ProcessingTimeExecutor(ProcessingTime("1000 milliseconds"), clock) + val executorThread = new Thread() { + override def run(): Unit = { + executor.execute(() => { + // Record the trigger time, increment clock if needed and + triggerTimes.add(clock.getTimeMillis.toInt) + clock.advance(clockIncrementInTrigger) + clockIncrementInTrigger = 0 // reset this so that there are no runaway triggers + continueExecuting + }) + } + } + executorThread.start() + // First batch should execute immediately, then executor should wait for next one + eventually { + assert(triggerTimes.contains(0)) + assert(clock.isStreamWaitingAt(0)) + assert(clock.isStreamWaitingFor(1000)) + } + + // Second batch should execute when clock reaches the next trigger time. 
+ // If next trigger takes less than the trigger interval, executor should wait for next one + clockIncrementInTrigger = 500 + clock.setTime(1000) + eventually { + assert(triggerTimes.contains(1000)) + assert(clock.isStreamWaitingAt(1500)) + assert(clock.isStreamWaitingFor(2000)) + } + + // If next trigger takes less than the trigger interval, executor should immediately execute + // another one + clockIncrementInTrigger = 1500 + clock.setTime(2000) // allow another trigger by setting clock to 2000 + eventually { + // Since the next trigger will take 1500 (which is more than trigger interval of 1000) + // executor will immediately execute another trigger + assert(triggerTimes.contains(2000) && triggerTimes.contains(3500)) + assert(clock.isStreamWaitingAt(3500)) + assert(clock.isStreamWaitingFor(4000)) + } + continueExecuting = false + clock.advance(1000) + waitForThreadJoin(executorThread) + } + test("calling nextBatchTime with the result of a previous call should return the next interval") { val intervalMS = 100 val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMS)) @@ -54,7 +115,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs)) processingTimeExecutor.execute(() => { batchCounts += 1 - // If the batch termination works well, batchCounts should be 3 after `execute` + // If the batch termination works correctly, batchCounts should be 3 after `execute` batchCounts < 3 }) assert(batchCounts === 3) @@ -66,9 +127,8 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } test("notifyBatchFallingBehind") { - val clock = new ManualClock() + val clock = new StreamManualClock() @volatile var batchFallingBehindCalled = false - val latch = new CountDownLatch(1) val t = new Thread() { override def run(): Unit = { val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) { @@ -77,7 +137,6 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } } processingTimeExecutor.execute(() => { - latch.countDown() clock.waitTillTime(200) false }) @@ -85,9 +144,17 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } t.start() // Wait until the batch is running so that we don't call `advance` too early - assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds") + eventually { assert(clock.isStreamWaitingFor(200)) } clock.advance(200) - t.join() + waitForThreadJoin(t) assert(batchFallingBehindCalled === true) } + + private def eventually(body: => Unit): Unit = { + Eventually.eventually(Timeout(timeout)) { body } + } + + private def waitForThreadJoin(thread: Thread): Unit = { + failAfter(timeout) { thread.join() } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 171877abe6e9..26967782f77c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext import 
org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index c8e31e3ca2e0..85aa7dbe9ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,8 +21,6 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Eventually.eventually -import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -35,6 +33,7 @@ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} /** Class to check custom state types */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 388f15405e70..5ab9dc2bc776 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 951ff2ca0d68..03aa45b61688 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -214,24 +214,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { AssertOnQuery(query => { func(query); true }) } - class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { - private var waitStartTime: Option[Long] = None - - override def waitTillTime(targetTime: Long): Long = synchronized { - try { - waitStartTime = Some(getTimeMillis()) - super.waitTillTime(targetTime) - } finally { - waitStartTime = None - } - } - - def isStreamWaitingAt(time: Long): Boolean = synchronized { - waitStartTime == Some(time) - } - } - - /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. 
@@ -242,6 +224,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized { + import org.apache.spark.sql.streaming.util.StreamManualClock + // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently // because this method assumes there is only one active query in its `StreamingQueryListener` // and it may not work correctly when multiple `testStream`s run concurrently. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 600c039cd0b9..e5d5b4f32882 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.sql.streaming.util.StreamManualClock object FailureSinglton { var firstTime = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 03dad8a6ddbc..b8a694c17731 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1172531fe998..2ebbfcd22b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider} +import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.util.ManualClock @@ -207,46 +207,53 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { - // Wait for manual clock to be 100 first time there is data + // getOffset should take 50 ms the first time it is called override def getOffset: Option[Offset] = { val offset = super.getOffset if (offset.nonEmpty) { - clock.waitTillTime(300) + clock.waitTillTime(1050) } offset } - // Wait for manual clock to be 300 first time there is data + // getBatch should take 100 ms the 
first time it is called override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - clock.waitTillTime(600) + if (start.isEmpty) clock.waitTillTime(1150) super.getBatch(start, end) } } - // This is to make sure thatquery waits for manual clock to be 600 first time there is data - val mapped = inputData.toDS().as[Long].map { x => - clock.waitTillTime(1100) + // query execution should take 350 ms the first time it is called + val mapped = inputData.toDS.coalesce(1).as[Long].map { x => + clock.waitTillTime(1500) // this will only wait the first time when clock < 1500 10 / x }.agg(count("*")).as[Long] - case class AssertStreamExecThreadToWaitForClock() + case class AssertStreamExecThreadIsWaitingForTime(targetTime: Long) extends AssertOnQuery(q => { eventually(Timeout(streamingTimeout)) { if (q.exception.isEmpty) { - assert(clock.asInstanceOf[StreamManualClock].isStreamWaitingAt(clock.getTimeMillis)) + assert(clock.isStreamWaitingFor(targetTime)) } } if (q.exception.isDefined) { throw q.exception.get } true - }, "") + }, "") { + override def toString: String = s"AssertStreamExecThreadIsWaitingForTime($targetTime)" + } + + case class AssertClockTime(time: Long) + extends AssertOnQuery(q => clock.getTimeMillis() === time, "") { + override def toString: String = s"AssertClockTime($time)" + } var lastProgressBeforeStop: StreamingQueryProgress = null testStream(mapped, OutputMode.Complete)( - StartStream(ProcessingTime(100), triggerClock = clock), - AssertStreamExecThreadToWaitForClock(), + StartStream(ProcessingTime(1000), triggerClock = clock), + AssertStreamExecThreadIsWaitingForTime(1000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -254,33 +261,37 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi // Test status and progress while offset is being fetched AddData(inputData, 1, 2), - AdvanceManualClock(100), // time = 100 to start new trigger, will block on getOffset - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset + AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being fetched - AdvanceManualClock(200), // time = 300 to unblock getOffset, will block on getBatch - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(50), // time = 1050 to unblock getOffset + AssertClockTime(1050), + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being processed - AdvanceManualClock(300), // time = 600 to unblock getBatch, will block in Spark job + AdvanceManualClock(100), // time = 1150 to unblock getBatch + AssertClockTime(1150), + AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), 
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(500), // time = 1100 to unblock job - AssertOnQuery { _ => clock.getTimeMillis() === 1100 }, + AssertOnQuery { _ => clock.getTimeMillis() === 1150 }, + AdvanceManualClock(350), // time = 1500 to unblock job + AssertClockTime(1500), CheckAnswer(2), - AssertStreamExecThreadToWaitForClock(), + AssertStreamExecThreadIsWaitingForTime(2000), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -293,21 +304,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.id === query.id) assert(progress.name === query.name) assert(progress.batchId === 0) - assert(progress.timestamp === "1970-01-01T00:00:00.100Z") // 100 ms in UTC + assert(progress.timestamp === "1970-01-01T00:00:01.000Z") // 100 ms in UTC assert(progress.numInputRows === 2) - assert(progress.processedRowsPerSecond === 2.0) + assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("getOffset") === 200) - assert(progress.durationMs.get("getBatch") === 300) + assert(progress.durationMs.get("getOffset") === 50) + assert(progress.durationMs.get("getBatch") === 100) assert(progress.durationMs.get("queryPlanning") === 0) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("triggerExecution") === 1000) + assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") assert(progress.sources(0).startOffset === null) assert(progress.sources(0).endOffset !== null) - assert(progress.sources(0).processedRowsPerSecond === 2.0) + assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) assert(progress.stateOperators(0).numRowsUpdated === 1) @@ -317,9 +328,12 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi true }, + // Test whether input rate is updated after two batches + AssertStreamExecThreadIsWaitingForTime(2000), // blocked waiting for next trigger time AddData(inputData, 1, 2), - AdvanceManualClock(100), // allow another trigger - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(500), // allow another trigger + AssertClockTime(2000), + AssertStreamExecThreadIsWaitingForTime(3000), // will block waiting for next trigger time CheckAnswer(4), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), @@ -327,13 +341,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery { query => assert(query.recentProgress.last.eq(query.lastProgress)) assert(query.lastProgress.batchId === 1) - assert(query.lastProgress.sources(0).inputRowsPerSecond === 1.818) + assert(query.lastProgress.inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) true }, // Test status and progress after data is not available for a trigger - AdvanceManualClock(100), // allow another trigger - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(1000), // allow another trigger + AssertStreamExecThreadIsWaitingForTime(4000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for 
next trigger"), @@ -350,10 +365,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Stopped"), // Test status and progress after query terminated with error - StartStream(ProcessingTime(100), triggerClock = clock), - AdvanceManualClock(100), // ensure initial trigger completes before AddData + StartStream(ProcessingTime(1000), triggerClock = clock), + AdvanceManualClock(1000), // ensure initial trigger completes before AddData AddData(inputData, 0), - AdvanceManualClock(100), // allow another trigger + AdvanceManualClock(1000), // allow another trigger ExpectFailure[SparkException](), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), @@ -678,5 +693,5 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi object StreamingQuerySuite { // Singleton reference to clock that does not get serialized in task closures - var clock: ManualClock = null + var clock: StreamManualClock = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala new file mode 100644 index 000000000000..c769a790a416 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.util + +import org.apache.spark.util.ManualClock + +/** + * ManualClock used for streaming tests that allows checking whether the stream is waiting + * on the clock at expected times. 
+ */ +class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { + private var waitStartTime: Option[Long] = None + private var waitTargetTime: Option[Long] = None + + override def waitTillTime(targetTime: Long): Long = synchronized { + try { + waitStartTime = Some(getTimeMillis()) + waitTargetTime = Some(targetTime) + super.waitTillTime(targetTime) + } finally { + waitStartTime = None + waitTargetTime = None + } + } + + /** Is the streaming thread waiting for the clock to advance when it is at the given time */ + def isStreamWaitingAt(time: Long): Boolean = synchronized { + waitStartTime == Some(time) + } + + /** Is the streaming thread waiting for clock to advance to the given time */ + def isStreamWaitingFor(target: Long): Boolean = synchronized { + waitTargetTime == Some(target) + } +} + From 6f09dc70d9808cae004ceda9ad615aa9be50f43d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oliver=20K=C3=B6th?= Date: Wed, 5 Apr 2017 08:09:42 +0100 Subject: [PATCH 214/512] [SPARK-20042][WEB UI] Fix log page buttons for reverse proxy mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit with spark.ui.reverseProxy=true, full path URLs like /log will point to the master web endpoint which is serving the worker UI as reverse proxy. To access a REST endpoint in the worker in reverse proxy mode , the leading /proxy/"target"/ part of the base URI must be retained. Added logic to log-view.js to handle this, similar to executorspage.js Patch was tested manually Author: Oliver Köth Closes #17370 from okoethibm/master. --- .../org/apache/spark/ui/static/log-view.js | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/log-view.js b/core/src/main/resources/org/apache/spark/ui/static/log-view.js index 1782b4f209c0..b5c43e5788bc 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/log-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/log-view.js @@ -51,13 +51,26 @@ function noNewAlert() { window.setTimeout(function () {alert.css("display", "none");}, 4000); } + +function getRESTEndPoint() { + // If the worker is served from the master through a proxy (see doc on spark.ui.reverseProxy), + // we need to retain the leading ../proxy// part of the URL when making REST requests. + // Similar logic is contained in executorspage.js function createRESTEndPoint. 
+ var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + return words.slice(0, ind + 2).join('/') + "/log"; + } + return "/log" +} + function loadMore() { var offset = Math.max(startByte - byteLength, 0); var moreByteLength = Math.min(byteLength, startByte); $.ajax({ type: "GET", - url: "/log" + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, + url: getRESTEndPoint() + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, success: function (data) { var oldHeight = $(".log-content")[0].scrollHeight; var newlineIndex = data.indexOf('\n'); @@ -83,14 +96,14 @@ function loadMore() { function loadNew() { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=0", + url: getRESTEndPoint() + baseParams + "&byteLength=0", success: function (data) { var dataInfo = data.substring(0, data.indexOf('\n')).match(/\d+/g); var newDataLen = dataInfo[2] - totalLogLength; if (newDataLen != 0) { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=" + newDataLen, + url: getRESTEndPoint() + baseParams + "&byteLength=" + newDataLen, success: function (data) { var newlineIndex = data.indexOf('\n'); var dataInfo = data.substring(0, newlineIndex).match(/\d+/g); From 71c3c48159fe7eb4a46fc2a1b78b72088ccfa824 Mon Sep 17 00:00:00 2001 From: shaolinliu Date: Wed, 5 Apr 2017 13:47:44 +0100 Subject: [PATCH 215/512] [SPARK-19807][WEB UI] Add reason for cancellation when a stage is killed using web UI ## What changes were proposed in this pull request? When a user kills a stage using web UI (in Stages page), StagesTab.handleKillRequest requests SparkContext to cancel the stage without giving a reason. SparkContext has cancelStage(stageId: Int, reason: String) that Spark could use to pass the information for monitoring/debugging purposes. ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: shaolinliu Author: lvdongr Closes #17258 from shaolinliu/SPARK-19807. --- core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index c1f25114371f..181465bdf960 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -42,7 +42,7 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" val stageId = Option(request.getParameter("id")).map(_.toInt) stageId.foreach { id => if (progressListener.activeStages.contains(id)) { - sc.foreach(_.cancelStage(id)) + sc.foreach(_.cancelStage(id, "killed via the Web UI")) // Do a quick pause here to give Spark time to kill the stage so it shows up as // killed after the refresh. Note that this will block the serving thread so the // time should be limited in duration. From a2d8d767d933321426a4eb9df1583e017722d7d6 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Wed, 5 Apr 2017 10:21:43 -0700 Subject: [PATCH 216/512] [SPARK-20223][SQL] Fix typo in tpcds q77.sql ## What changes were proposed in this pull request? Fix typo in tpcds q77.sql ## How was this patch tested? N/A Author: wangzhenhua Closes #17538 from wzhfy/typoQ77. 
--- sql/core/src/test/resources/tpcds/q77.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/tpcds/q77.sql b/sql/core/src/test/resources/tpcds/q77.sql index 7830f96e7651..a69df9fbcd36 100755 --- a/sql/core/src/test/resources/tpcds/q77.sql +++ b/sql/core/src/test/resources/tpcds/q77.sql @@ -36,7 +36,7 @@ WITH ss AS sum(cr_net_loss) AS profit_loss FROM catalog_returns, date_dim WHERE cr_returned_date_sk = d_date_sk - AND d_date BETWEEN cast('2000-08-03]' AS DATE) AND + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND (cast('2000-08-03' AS DATE) + INTERVAL 30 days)), ws AS (SELECT From e2773996b8d1c0214d9ffac634a059b4923caf7b Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 5 Apr 2017 11:47:40 -0700 Subject: [PATCH 217/512] [SPARK-19454][PYTHON][SQL] DataFrame.replace improvements ## What changes were proposed in this pull request? - Allows skipping `value` argument if `to_replace` is a `dict`: ```python df = sc.parallelize([("Alice", 1, 3.0)]).toDF() df.replace({"Alice": "Bob"}).show() ```` - Adds validation step to ensure homogeneous values / replacements. - Simplifies internal control flow. - Improves unit tests coverage. ## How was this patch tested? Existing unit tests, additional unit tests, manual testing. Author: zero323 Closes #16793 from zero323/SPARK-19454. --- python/pyspark/sql/dataframe.py | 81 +++++++++++++++++++++++---------- python/pyspark/sql/tests.py | 72 +++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a24512f53c52..774caf53f3a4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,6 +25,8 @@ else: from itertools import imap as map +import warnings + from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -1281,7 +1283,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=None, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1326,43 +1328,72 @@ def replace(self, to_replace, value, subset=None): |null| null|null| +----+------+----+ """ - if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): + # Helper functions + def all_of(types): + """Given a type or tuple of types and a sequence of xs + check if each x is instance of type(s) + + >>> all_of(bool)([True, False]) + True + >>> all_of(basestring)(["a", 1]) + False + """ + def all_of_(xs): + return all(isinstance(x, types) for x in xs) + return all_of_ + + all_of_bool = all_of(bool) + all_of_str = all_of(basestring) + all_of_numeric = all_of((float, int, long)) + + # Validate input types + valid_types = (bool, float, int, long, basestring, list, tuple) + if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a float, int, long, string, list, tuple, or dict") + "to_replace should be a float, int, long, string, list, tuple, or dict. 
" + "Got {0}".format(type(to_replace))) - if not isinstance(value, (float, int, long, basestring, list, tuple)): - raise ValueError("value should be a float, int, long, string, list, or tuple") + if not isinstance(value, valid_types) and not isinstance(to_replace, dict): + raise ValueError("If to_replace is not a dict, value should be " + "a float, int, long, string, list, or tuple. " + "Got {0}".format(type(value))) + + if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): + if len(to_replace) != len(value): + raise ValueError("to_replace and value lists should be of the same length. " + "Got {0} and {1}".format(len(to_replace), len(value))) - rep_dict = dict() + if not (subset is None or isinstance(subset, (list, tuple, basestring))): + raise ValueError("subset should be a list or tuple of column names, " + "column name or None. Got {0}".format(type(subset))) + # Reshape input arguments if necessary if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(to_replace, tuple): - to_replace = list(to_replace) + if isinstance(value, (float, int, long, basestring)): + value = [value for _ in range(len(to_replace))] - if isinstance(value, tuple): - value = list(value) - - if isinstance(to_replace, list) and isinstance(value, list): - if len(to_replace) != len(value): - raise ValueError("to_replace and value lists should be of the same length") - rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): - rep_dict = dict([(tr, value) for tr in to_replace]) - elif isinstance(to_replace, dict): + if isinstance(to_replace, dict): rep_dict = to_replace + if value is not None: + warnings.warn("to_replace is a dict and value is not None. value will be ignored.") + else: + rep_dict = dict(zip(to_replace, value)) - if subset is None: - return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) - elif isinstance(subset, basestring): + if isinstance(subset, basestring): subset = [subset] - if not isinstance(subset, (list, tuple)): - raise ValueError("subset should be a list or tuple of column names") + # Verify we were not passed in mixed type generics." 
+ if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) + for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): + raise ValueError("Mixed type replacements are not supported") - return DataFrame( - self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + if subset is None: + return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) + else: + return DataFrame( + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) @since(2.0) def approxQuantile(self, col, probabilities, relativeError): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index db41b4edb6dd..2b2444304e04 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1779,6 +1779,78 @@ def test_replace(self): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + # replace with lists + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first() + self.assertTupleEqual(row, (u'Ann', 10, 80.1)) + + # replace with dict + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() + self.assertTupleEqual(row, (u'Alice', 11, 80.1)) + + # test backward compatibility with dummy value + dummy_value = 1 + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # test dict with mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', -10, 90.5)) + + # replace with tuples + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # replace multiple columns + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.0)) + + # test for mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + # replace with boolean + row = (self + .spark.createDataFrame([(u'Alice', 10, 80.0)], schema) + .selectExpr("name = 'Bob'", 'age <= 15') + .replace(False, True).first()) + self.assertTupleEqual(row, (True, True)) + + # should fail if subset is not list, tuple or None + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first() + + # should fail if to_replace and value have different length + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first() + + # should fail if when received unexpected type + with self.assertRaises(ValueError): + from datetime import datetime + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first() + + # should fail if provided mixed type replacements + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first() + + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], 
schema).replace({u"Alice": u"Bob", 10: 20}).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) From 9543fc0e08a21680961689ea772441c49fcd52ee Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 5 Apr 2017 16:03:04 -0700 Subject: [PATCH 218/512] [SPARK-20224][SS] Updated docs for streaming dropDuplicates and mapGroupsWithState ## What changes were proposed in this pull request? - Fixed bug in Java API not passing timeout conf to scala API - Updated markdown docs - Updated scala docs - Added scala and Java example ## How was this patch tested? Manually ran examples. Author: Tathagata Das Closes #17539 from tdas/SPARK-20224. --- .../structured-streaming-programming-guide.md | 98 ++++++- .../JavaStructuredSessionization.java | 255 ++++++++++++++++++ .../streaming/StructuredSessionization.scala | 151 +++++++++++ .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/streaming/GroupState.scala | 15 +- 5 files changed, 509 insertions(+), 12 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index b5cf9f164498..37a1d6189a42 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Structured Streaming Programming Guide [Alpha] +displayTitle: Structured Streaming Programming Guide [Experimental] title: Structured Streaming Programming Guide --- @@ -871,6 +871,65 @@ streamingDf.join(staticDf, "type", "right_join") # right outer join with a stat +### Streaming Deduplication +You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. + +- *With watermark* - If there is a upper bound on how late a duplicate record may arrive, then you can define a watermark on a event time column and deduplicate using both the guid and the event time columns. The query will use the watermark to remove old state data from past records that are not expected to get any duplicates any more. This bounds the amount of the state the query has to maintain. + +- *Without watermark* - Since there are no bounds on when a duplicate record may arrive, the query stores the data from all the past records as state. + +
+<div class="codetabs">
+<div data-lang="scala"  markdown="1">
+
+{% highlight scala %}
+val streamingDf = spark.readStream. ...  // columns: guid, eventTime, ...
+
+// Without watermark using guid column
+streamingDf.dropDuplicates("guid")
+
+// With watermark using guid and eventTime columns
+streamingDf
+  .withWatermark("eventTime", "10 seconds")
+  .dropDuplicates("guid", "eventTime")
+{% endhighlight %}
+
+</div>
+<div data-lang="java"  markdown="1">
+
+{% highlight java %}
+Dataset<Row> streamingDf = spark.readStream. ...;  // columns: guid, eventTime, ...
+
+// Without watermark using guid column
+streamingDf.dropDuplicates("guid");
+
+// With watermark using guid and eventTime columns
+streamingDf
+  .withWatermark("eventTime", "10 seconds")
+  .dropDuplicates("guid", "eventTime");
+{% endhighlight %}
+
+</div>
+<div data-lang="python"  markdown="1">
+
+{% highlight python %}
+streamingDf = spark.readStream. ...
+
+# Without watermark using guid column
+streamingDf.dropDuplicates("guid")
+
+# With watermark using guid and eventTime columns
+streamingDf \
+  .withWatermark("eventTime", "10 seconds") \
+  .dropDuplicates("guid", "eventTime")
+{% endhighlight %}
+
+</div>
+</div>
+
+### Arbitrary Stateful Operations
+Many use cases require more advanced stateful operations than aggregations. For example, in many use cases, you have to track sessions from data streams of events. To do such sessionization, you will have to save arbitrary types of data as state, and perform arbitrary operations on the state using the data stream events in every trigger. Since Spark 2.2, this can be done using the operation `mapGroupsWithState` and the more powerful operation `flatMapGroupsWithState`. Both operations allow you to apply user-defined code on grouped Datasets to update user-defined state. For more concrete details, take a look at the API documentation ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.GroupState)/[Java](api/java/org/apache/spark/sql/streaming/GroupState.html)) and the examples ([Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java)).
+
 ### Unsupported Operations
 There are a few DataFrame/Dataset operations that are not supported with streaming DataFrames/Datasets.
 Some of them are as follows.
@@ -891,7 +950,7 @@ Some of them are as follows.
     + Right outer join with a streaming Dataset on the left is not supported
-- Any kind of joins between two streaming Datasets are not yet supported.
+- Any kind of joins between two streaming Datasets is not yet supported.
 In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that).
@@ -951,13 +1010,6 @@ Here is the compatibility matrix.
-    Queries without aggregation | Append, Update | Complete mode not supported as it is infeasible to keep all data in the Result Table.
@@ -986,6 +1038,33 @@ Here is the compatibility matrix.
        this mode.
     Supported Output Modes | Notes
     Queries with aggregation | Aggregation on event-time with watermark
+    Queries with mapGroupsWithState | Update
+    Queries with flatMapGroupsWithState | Append operation mode | Append | Aggregations are allowed after flatMapGroupsWithState.
+    Queries with flatMapGroupsWithState | Update operation mode | Update | Aggregations not allowed after flatMapGroupsWithState.
+    Other queries | Append, Update | Complete mode not supported as it is infeasible to keep all unaggregated data in the Result Table.
    + #### Output Sinks There are a few types of built-in output sinks. diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java new file mode 100644 index 000000000000..da3a5dfe8628 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql.streaming; + +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapGroupsWithStateFunction; +import org.apache.spark.sql.*; +import org.apache.spark.sql.streaming.GroupState; +import org.apache.spark.sql.streaming.GroupStateTimeout; +import org.apache.spark.sql.streaming.StreamingQuery; + +import java.io.Serializable; +import java.sql.Timestamp; +import java.util.*; + +import scala.Tuple2; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + *

+ * Usage: JavaStructuredNetworkWordCount <hostname> <port>
+ * <hostname> and <port> describe the TCP server that Structured Streaming
+ * would connect to receive data.
+ *

    + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.JavaStructuredSessionization + * localhost 9999` + */ +public final class JavaStructuredSessionization { + + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaStructuredSessionization "); + System.exit(1); + } + + String host = args[0]; + int port = Integer.parseInt(args[1]); + + SparkSession spark = SparkSession + .builder() + .appName("JavaStructuredSessionization") + .getOrCreate(); + + // Create DataFrame representing the stream of input lines from connection to host:port + Dataset lines = spark + .readStream() + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load(); + + FlatMapFunction linesToEvents = + new FlatMapFunction() { + @Override + public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exception { + ArrayList eventList = new ArrayList(); + for (String word : lineWithTimestamp.getLine().split(" ")) { + eventList.add(new Event(word, lineWithTimestamp.getTimestamp())); + } + System.out.println( + "Number of events from " + lineWithTimestamp.getLine() + " = " + eventList.size()); + return eventList.iterator(); + } + }; + + // Split the lines into words, treat words as sessionId of events + Dataset events = lines + .withColumnRenamed("value", "line") + .as(Encoders.bean(LineWithTimestamp.class)) + .flatMap(linesToEvents, Encoders.bean(Event.class)); + + // Sessionize the events. Track number of events, start and end timestamps of session, and + // and report session updates. + // + // Step 1: Define the state update function + MapGroupsWithStateFunction stateUpdateFunc = + new MapGroupsWithStateFunction() { + @Override public SessionUpdate call( + String sessionId, Iterator events, GroupState state) + throws Exception { + // If timed out, then remove session and send final update + if (state.hasTimedOut()) { + SessionUpdate finalUpdate = new SessionUpdate( + sessionId, state.get().getDurationMs(), state.get().getNumEvents(), true); + state.remove(); + return finalUpdate; + + } else { + // Find max and min timestamps in events + long maxTimestampMs = Long.MIN_VALUE; + long minTimestampMs = Long.MAX_VALUE; + int numNewEvents = 0; + while (events.hasNext()) { + Event e = events.next(); + long timestampMs = e.getTimestamp().getTime(); + maxTimestampMs = Math.max(timestampMs, maxTimestampMs); + minTimestampMs = Math.min(timestampMs, minTimestampMs); + numNewEvents += 1; + } + SessionInfo updatedSession = new SessionInfo(); + + // Update start and end timestamps in session + if (state.exists()) { + SessionInfo oldSession = state.get(); + updatedSession.setNumEvents(oldSession.numEvents + numNewEvents); + updatedSession.setStartTimestampMs(oldSession.startTimestampMs); + updatedSession.setEndTimestampMs(Math.max(oldSession.endTimestampMs, maxTimestampMs)); + } else { + updatedSession.setNumEvents(numNewEvents); + updatedSession.setStartTimestampMs(minTimestampMs); + updatedSession.setEndTimestampMs(maxTimestampMs); + } + state.update(updatedSession); + // Set timeout such that the session will be expired if no data received for 10 seconds + state.setTimeoutDuration("10 seconds"); + return new SessionUpdate( + sessionId, state.get().getDurationMs(), state.get().getNumEvents(), false); + } + } + }; + + // Step 2: Apply the state update function to the events streaming Dataset grouped by 
sessionId + Dataset sessionUpdates = events + .groupByKey( + new MapFunction() { + @Override public String call(Event event) throws Exception { + return event.getSessionId(); + } + }, Encoders.STRING()) + .mapGroupsWithState( + stateUpdateFunc, + Encoders.bean(SessionInfo.class), + Encoders.bean(SessionUpdate.class), + GroupStateTimeout.ProcessingTimeTimeout()); + + // Start running the query that prints the session updates to the console + StreamingQuery query = sessionUpdates + .writeStream() + .outputMode("update") + .format("console") + .start(); + + query.awaitTermination(); + } + + /** + * User-defined data type representing the raw lines with timestamps. + */ + public static class LineWithTimestamp implements Serializable { + private String line; + private Timestamp timestamp; + + public Timestamp getTimestamp() { return timestamp; } + public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } + + public String getLine() { return line; } + public void setLine(String sessionId) { this.line = sessionId; } + } + + /** + * User-defined data type representing the input events + */ + public static class Event implements Serializable { + private String sessionId; + private Timestamp timestamp; + + public Event() { } + public Event(String sessionId, Timestamp timestamp) { + this.sessionId = sessionId; + this.timestamp = timestamp; + } + + public Timestamp getTimestamp() { return timestamp; } + public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } + + public String getSessionId() { return sessionId; } + public void setSessionId(String sessionId) { this.sessionId = sessionId; } + } + + /** + * User-defined data type for storing a session information as state in mapGroupsWithState. + */ + public static class SessionInfo implements Serializable { + private int numEvents = 0; + private long startTimestampMs = -1; + private long endTimestampMs = -1; + + public int getNumEvents() { return numEvents; } + public void setNumEvents(int numEvents) { this.numEvents = numEvents; } + + public long getStartTimestampMs() { return startTimestampMs; } + public void setStartTimestampMs(long startTimestampMs) { + this.startTimestampMs = startTimestampMs; + } + + public long getEndTimestampMs() { return endTimestampMs; } + public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; } + + public long getDurationMs() { return endTimestampMs - startTimestampMs; } + @Override public String toString() { + return "SessionInfo(numEvents = " + numEvents + + ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")"; + } + } + + /** + * User-defined data type representing the update information returned by mapGroupsWithState. 
+ */ + public static class SessionUpdate implements Serializable { + private String id; + private long durationMs; + private int numEvents; + private boolean expired; + + public SessionUpdate() { } + + public SessionUpdate(String id, long durationMs, int numEvents, boolean expired) { + this.id = id; + this.durationMs = durationMs; + this.numEvents = numEvents; + this.expired = expired; + } + + public String getId() { return id; } + public void setId(String id) { this.id = id; } + + public long getDurationMs() { return durationMs; } + public void setDurationMs(long durationMs) { this.durationMs = durationMs; } + + public int getNumEvents() { return numEvents; } + public void setNumEvents(int numEvents) { this.numEvents = numEvents; } + + public boolean isExpired() { return expired; } + public void setExpired(boolean expired) { this.expired = expired; } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala new file mode 100644 index 000000000000..2ce792c00849 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.sql.streaming + +import java.sql.Timestamp + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.streaming._ + + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + * + * Usage: MapGroupsWithState + * and describe the TCP server that Structured Streaming + * would connect to receive data. 
+ * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.StructuredNetworkWordCount + * localhost 9999` + */ +object StructuredSessionization { + + def main(args: Array[String]): Unit = { + if (args.length < 2) { + System.err.println("Usage: StructuredNetworkWordCount ") + System.exit(1) + } + + val host = args(0) + val port = args(1).toInt + + val spark = SparkSession + .builder + .appName("StructuredSessionization") + .getOrCreate() + + import spark.implicits._ + + // Create DataFrame representing the stream of input lines from connection to host:port + val lines = spark.readStream + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load() + + // Split the lines into words, treat words as sessionId of events + val events = lines + .as[(String, Timestamp)] + .flatMap { case (line, timestamp) => + line.split(" ").map(word => Event(sessionId = word, timestamp)) + } + + // Sessionize the events. Track number of events, start and end timestamps of session, and + // and report session updates. + val sessionUpdates = events + .groupByKey(event => event.sessionId) + .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) { + + case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) => + + // If timed out, then remove session and send final update + if (state.hasTimedOut) { + val finalUpdate = + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true) + state.remove() + finalUpdate + } else { + // Update start and end timestamps in session + val timestamps = events.map(_.timestamp.getTime).toSeq + val updatedSession = if (state.exists) { + val oldSession = state.get + SessionInfo( + oldSession.numEvents + timestamps.size, + oldSession.startTimestampMs, + math.max(oldSession.endTimestampMs, timestamps.max)) + } else { + SessionInfo(timestamps.size, timestamps.min, timestamps.max) + } + state.update(updatedSession) + + // Set timeout such that the session will be expired if no data received for 10 seconds + state.setTimeoutDuration("10 seconds") + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false) + } + } + + // Start running the query that prints the session updates to the console + val query = sessionUpdates + .writeStream + .outputMode("update") + .format("console") + .start() + + query.awaitTermination() + } +} +/** User-defined data type representing the input events */ +case class Event(sessionId: String, timestamp: Timestamp) + +/** + * User-defined data type for storing a session information as state in mapGroupsWithState. + * + * @param numEvents total number of events received in the session + * @param startTimestampMs timestamp of first event received in the session when it started + * @param endTimestampMs timestamp of last event received in the session before it expired + */ +case class SessionInfo( + numEvents: Int, + startTimestampMs: Long, + endTimestampMs: Long) { + + /** Duration of the session, between the first and last events */ + def durationMs: Long = endTimestampMs - startTimestampMs +} + +/** + * User-defined data type representing the update information returned by mapGroupsWithState. 
+ * + * @param id Id of the session + * @param durationMs Duration the session was active, that is, from first event to its expiry + * @param numEvents Number of events received by the session while it was active + * @param expired Is the session active or expired + */ +case class SessionUpdate( + id: String, + durationMs: Long, + numEvents: Int, + expired: Boolean) + +// scalastyle:on println + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 022c2f5629e8..cb42e9e4560c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -347,7 +347,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout): Dataset[U] = { - mapGroupsWithState[S, U]( + mapGroupsWithState[S, U](timeoutConf)( (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 15df906ca7b1..c659ac7fcf3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger. - * That is, in every batch of the `streaming.StreamingQuery`, + * That is, in every batch of the `StreamingQuery`, * the function will be invoked once for each group that has data in the trigger. Furthermore, * if timeout is set, then the function will invoked on timed out groups (more detail below). * @@ -42,12 +42,23 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState * - The key of the group. * - An iterator containing all the values for this group. * - A user-defined state object set by previous invocations of the given function. + * * In case of a batch Dataset, there is only one invocation and state object will be empty as * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` * is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have * no effect. * - * Important points to note about the function. + * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the + * former allows the function to return one and only one record, whereas the latter + * allows the function to return any number of records (including no records). Furthermore, the + * `flatMapGroupsWithState` is associated with an operation output mode, which can be either + * `Append` or `Update`. Semantically, this defines whether the output records of one trigger + * is effectively replacing the previously output records (from previous triggers) or is appending + * to the list of previously output records. Essentially, this defines how the Result Table (refer + * to the semantics in the programming guide) is updated, and allows us to reason about the + * semantics of later operations. 
+ * + * Important points to note about the function (both mapGroupsWithState and flatMapGroupsWithState). * - In a trigger, the function will be called only the groups present in the batch. So do not * assume that the function will be called in every trigger for every group that has state. * - There is no guaranteed ordering of values in the iterator in the function, neither with From 9d68c67235481fa33983afb766916b791ca8212a Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 6 Apr 2017 08:33:14 +0800 Subject: [PATCH 219/512] [SPARK-20204][SQL][FOLLOWUP] SQLConf should react to change in default timezone settings ## What changes were proposed in this pull request? Make sure SESSION_LOCAL_TIMEZONE reflects the change in JVM's default timezone setting. Currently several timezone related tests fail as the change to default timezone is not picked up by SQLConf. ## How was this patch tested? Added an unit test in ConfigEntrySuite Author: Dilip Biswal Closes #17537 from dilipbiswal/timezone_debug. --- .../spark/internal/config/ConfigBuilder.scala | 8 ++++++++ .../spark/internal/config/ConfigEntry.scala | 17 +++++++++++++++++ .../internal/config/ConfigEntrySuite.scala | 9 +++++++++ .../org/apache/spark/sql/internal/SQLConf.scala | 2 +- 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index b9921138cc6c..e5d60a7ef098 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -147,6 +147,14 @@ private[spark] class TypedConfigBuilder[T]( } } + /** Creates a [[ConfigEntry]] with a function to determine the default value */ + def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = { + val entry = new ConfigEntryWithDefaultFunction[T](parent.key, defaultFunc, converter, + stringConverter, parent._doc, parent._public) + parent._onCreate.foreach(_ (entry)) + entry + } + /** * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a * [[String]] and must be a valid value for the entry. 
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index 4f3e42bb3c94..e86712e84d6a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -78,7 +78,24 @@ private class ConfigEntryWithDefault[T] ( def readFrom(reader: ConfigReader): T = { reader.get(key).map(valueConverter).getOrElse(_defaultValue) } +} + +private class ConfigEntryWithDefaultFunction[T] ( + key: String, + _defaultFunction: () => T, + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean) + extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + + override def defaultValue: Option[T] = Some(_defaultFunction()) + override def defaultValueString: String = stringConverter(_defaultFunction()) + + def readFrom(reader: ConfigReader): T = { + reader.get(key).map(valueConverter).getOrElse(_defaultFunction()) + } } private class ConfigEntryWithDefaultString[T] ( diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 3ff7e84d73bd..e2ba0d2a53d0 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -251,4 +251,13 @@ class ConfigEntrySuite extends SparkFunSuite { .createWithDefault(null) testEntryRef(nullConf, ref(nullConf)) } + + test("conf entry : default function") { + var data = 0 + val conf = new SparkConf() + val iConf = ConfigBuilder(testKey("intval")).intConf.createWithDefaultFunction(() => data) + assert(conf.get(iConf) === 0) + data = 2 + assert(conf.get(iConf) === 2) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5b5d547f8fe5..e685c2bed50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -752,7 +752,7 @@ object SQLConf { buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""") .stringConf - .createWithDefault(TimeZone.getDefault().getID()) + .createWithDefaultFunction(() => TimeZone.getDefault.getID) val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.windowExec.buffer.spill.threshold") From 12206058e8780e202c208b92774df3773eff36ae Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 5 Apr 2017 17:46:44 -0700 Subject: [PATCH 220/512] [SPARK-20214][ML] Make sure converted csc matrix has sorted indices ## What changes were proposed in this pull request? `_convert_to_vector` converts a scipy sparse matrix to csc matrix for initializing `SparseVector`. 
However, it doesn't guarantee the converted csc matrix has sorted indices and so a failure happens when you do something like that: from scipy.sparse import lil_matrix lil = lil_matrix((4, 1)) lil[1, 0] = 1 lil[3, 0] = 2 _convert_to_vector(lil.todok()) File "/home/jenkins/workspace/python/pyspark/mllib/linalg/__init__.py", line 78, in _convert_to_vector return SparseVector(l.shape[0], csc.indices, csc.data) File "/home/jenkins/workspace/python/pyspark/mllib/linalg/__init__.py", line 556, in __init__ % (self.indices[i], self.indices[i + 1])) TypeError: Indices 3 and 1 are not strictly increasing A simple test can confirm that `dok_matrix.tocsc()` won't guarantee sorted indices: >>> from scipy.sparse import lil_matrix >>> lil = lil_matrix((4, 1)) >>> lil[1, 0] = 1 >>> lil[3, 0] = 2 >>> dok = lil.todok() >>> csc = dok.tocsc() >>> csc.has_sorted_indices 0 >>> csc.indices array([3, 1], dtype=int32) I checked the source codes of scipy. The only way to guarantee it is `csc_matrix.tocsr()` and `csr_matrix.tocsc()`. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17532 from viirya/make-sure-sorted-indices. --- python/pyspark/ml/linalg/__init__.py | 3 +++ python/pyspark/mllib/linalg/__init__.py | 3 +++ python/pyspark/mllib/tests.py | 11 +++++++++++ 3 files changed, 17 insertions(+) diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index b76534325196..ad1b487676fa 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -72,7 +72,10 @@ def _convert_to_vector(l): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" + # Make sure the converted csc_matrix has sorted indices. csc = l.tocsc() + if not csc.has_sorted_indices: + csc.sort_indices() return SparseVector(l.shape[0], csc.indices, csc.data) else: raise TypeError("Cannot convert type %s into Vector" % type(l)) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 031f22c02098..7b24b3c74a9f 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -74,7 +74,10 @@ def _convert_to_vector(l): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" + # Make sure the converted csc_matrix has sorted indices. 
csc = l.tocsc() + if not csc.has_sorted_indices: + csc.sort_indices() return SparseVector(l.shape[0], csc.indices, csc.data) else: raise TypeError("Cannot convert type %s into Vector" % type(l)) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index c519883cdd73..523b3f111331 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -853,6 +853,17 @@ def serialize(l): self.assertEqual(sv, serialize(lil.tocsr())) self.assertEqual(sv, serialize(lil.todok())) + def test_convert_to_vector(self): + from scipy.sparse import csc_matrix + # Create a CSC matrix with non-sorted indices + indptr = array([0, 2]) + indices = array([3, 1]) + data = array([2.0, 1.0]) + csc = csc_matrix((data, indices, indptr)) + self.assertFalse(csc.has_sorted_indices) + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(csc)) + def test_dot(self): from scipy.sparse import lil_matrix lil = lil_matrix((4, 1)) From 4000f128b7101484ba618115504ca916c22fa84a Mon Sep 17 00:00:00 2001 From: Ioana Delaney Date: Wed, 5 Apr 2017 18:02:53 -0700 Subject: [PATCH 221/512] [SPARK-20231][SQL] Refactor star schema code for the subsequent star join detection in CBO ## What changes were proposed in this pull request? This commit moves star schema code from ```join.scala``` to ```StarSchemaDetection.scala```. It also applies some minor fixes in ```StarJoinReorderSuite.scala```. ## How was this patch tested? Run existing ```StarJoinReorderSuite.scala```. Author: Ioana Delaney Closes #17544 from ioana-delaney/starSchemaCBOv2. --- .../optimizer/StarSchemaDetection.scala | 351 ++++++++++++++++++ .../spark/sql/catalyst/optimizer/joins.scala | 328 +--------------- .../optimizer/StarJoinReorderSuite.scala | 4 +- 3 files changed, 354 insertions(+), 329 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala new file mode 100644 index 000000000000..91cb004eaec4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Encapsulates star-schema detection logic. 
+ */ +case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { + + /** + * Star schema consists of one or more fact tables referencing a number of dimension + * tables. In general, star-schema joins are detected using the following conditions: + * 1. Informational RI constraints (reliable detection) + * + Dimension contains a primary key that is being joined to the fact table. + * + Fact table contains foreign keys referencing multiple dimension tables. + * 2. Cardinality based heuristics + * + Usually, the table with the highest cardinality is the fact table. + * + Table being joined with the most number of tables is the fact table. + * + * To detect star joins, the algorithm uses a combination of the above two conditions. + * The fact table is chosen based on the cardinality heuristics, and the dimension + * tables are chosen based on the RI constraints. A star join will consist of the largest + * fact table joined with the dimension tables on their primary keys. To detect that a + * column is a primary key, the algorithm uses table and column statistics. + * + * The algorithm currently returns only the star join with the largest fact table. + * Choosing the largest fact table on the driving arm to avoid large inners is in + * general a good heuristic. This restriction will be lifted to observe multiple + * star joins. + * + * The highlights of the algorithm are the following: + * + * Given a set of joined tables/plans, the algorithm first verifies if they are eligible + * for star join detection. An eligible plan is a base table access with valid statistics. + * A base table access represents Project or Filter operators above a LeafNode. Conservatively, + * the algorithm only considers base table access as part of a star join since they provide + * reliable statistics. This restriction can be lifted with the CBO enablement by default. + * + * If some of the plans are not base table access, or statistics are not available, the algorithm + * returns an empty star join plan since, in the absence of statistics, it cannot make + * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality + * (number of rows), which is assumed to be a fact table. + * + * Next, it computes the set of dimension tables for the current fact table. A dimension table + * is assumed to be in a RI relationship with a fact table. To infer column uniqueness, + * the algorithm compares the number of distinct values with the total number of rows in the + * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted + * based on 1TB TPC-DS data), the column is assumed to be unique. + */ + def findStarJoins( + input: Seq[LogicalPlan], + conditions: Seq[Expression]): Seq[LogicalPlan] = { + + val emptyStarJoinPlan = Seq.empty[LogicalPlan] + + if (!conf.starSchemaDetection || input.size < 2) { + emptyStarJoinPlan + } else { + // Find if the input plans are eligible for star join detection. + // An eligible plan is a base table access with valid statistics. + val foundEligibleJoin = input.forall { + case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case _ => false + } + + if (!foundEligibleJoin) { + // Some plans don't have stats or are complex plans. Conservatively, + // return an empty star join. This restriction can be lifted + // once statistics are propagated in the plan. + emptyStarJoinPlan + } else { + // Find the fact table using cardinality based heuristics i.e. 
+ // the table with the largest number of rows. + val sortedFactTables = input.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.collect { case t @ TableAccessCardinality(_, Some(_)) => + t + }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse) + + sortedFactTables match { + case Nil => + emptyStarJoinPlan + case table1 :: table2 :: _ + if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble => + // If the top largest tables have comparable number of rows, return an empty star plan. + // This restriction will be lifted when the algorithm is generalized + // to return multiple star plans. + emptyStarJoinPlan + case TableAccessCardinality(factTable, _) :: rest => + // Find the fact table joins. + val allFactJoins = rest.collect { case TableAccessCardinality(plan, _) + if findJoinConditions(factTable, plan, conditions).nonEmpty => + plan + } + + // Find the corresponding join conditions. + val allFactJoinCond = allFactJoins.flatMap { plan => + val joinCond = findJoinConditions(factTable, plan, conditions) + joinCond + } + + // Verify if the join columns have valid statistics. + // Allow any relational comparison between the tables. Later + // we will heuristically choose a subset of equi-join + // tables. + val areStatsAvailable = allFactJoins.forall { dimTable => + allFactJoinCond.exists { + case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs + hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable) + case _ => false + } + } + + if (!areStatsAvailable) { + emptyStarJoinPlan + } else { + // Find the subset of dimension tables. A dimension table is assumed to be in a + // RI relationship with the fact table. Only consider equi-joins + // between a fact and a dimension table to avoid expanding joins. + val eligibleDimPlans = allFactJoins.filter { dimTable => + allFactJoinCond.exists { + case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + isUnique(dimCol, dimTable) + case _ => false + } + } + + if (eligibleDimPlans.isEmpty || eligibleDimPlans.size < 2) { + // An eligible star join was not found since the join is not + // an RI join, or the star join is an expanding join. + // Also, a star would involve more than one dimension table. + emptyStarJoinPlan + } else { + factTable +: eligibleDimPlans + } + } + } + } + } + } + + /** + * Determines if a column referenced by a base table access is a primary key. + * A column is a PK if it is not nullable and has unique values. + * To determine if a column has unique values in the absence of informational + * RI constraints, the number of distinct values is compared to the total + * number of rows in the table. If their relative difference + * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based + * on TPC-DS data results), the column is assumed to have unique values. 
+ */ + private def isUnique( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.rowCount match { + case Some(rowCount) if rowCount >= 0 => + if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { + val colStats = stats.attributeStats.get(col) + if (colStats.get.nullCount > 0) { + false + } else { + val distinctCount = colStats.get.distinctCount + val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) + // ndvMaxErr adjusted based on TPCDS 1TB data results + relDiff <= conf.ndvMaxError * 2 + } + } else { + false + } + case None => false + } + case None => false + } + case _ => false + } + + /** + * Given a column over a base table access, it returns + * the leaf node column from which the input column is derived. + */ + @tailrec + private def findLeafNodeCol( + column: Attribute, + plan: LogicalPlan): Option[Attribute] = plan match { + case pl @ PhysicalOperation(_, _, _: LeafNode) => + pl match { + case t: LeafNode if t.outputSet.contains(column) => + Option(column) + case p: Project if p.outputSet.exists(_.semanticEquals(column)) => + val col = p.outputSet.find(_.semanticEquals(column)).get + findLeafNodeCol(col, p.child) + case f: Filter => + findLeafNodeCol(column, f.child) + case _ => None + } + case _ => None + } + + /** + * Checks if a column has statistics. + * The column is assumed to be over a base table access. + */ + private def hasStatistics( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) + case None => false + } + case _ => false + } + + /** + * Returns the join predicates between two input plans. It only + * considers basic comparison operators. + */ + @inline + private def findJoinConditions( + plan1: LogicalPlan, + plan2: LogicalPlan, + conditions: Seq[Expression]): Seq[Expression] = { + val refs = plan1.outputSet ++ plan2.outputSet + conditions.filter { + case BinaryComparison(_, _) => true + case _ => false + }.filterNot(canEvaluate(_, plan1)) + .filterNot(canEvaluate(_, plan2)) + .filter(_.references.subsetOf(refs)) + } + + /** + * Checks if a star join is a selective join. A star join is assumed + * to be selective if there are local predicates on the dimension + * tables. + */ + private def isSelectiveStarJoin( + dimTables: Seq[LogicalPlan], + conditions: Seq[Expression]): Boolean = dimTables.exists { + case plan @ PhysicalOperation(_, p, _: LeafNode) => + // Checks if any condition applies to the dimension tables. + // Exclude the IsNotNull predicates until predicate selectivity is available. + // In most cases, this predicate is artificially introduced by the Optimizer + // to enforce nullability constraints. + val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull]) + .exists(canEvaluate(_, plan)) + + // Checks if there are any predicates pushed down to the base table access. + val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull]) + + localPredicates || pushedDownPredicates + case _ => false + } + + /** + * Helper case class to hold (plan, rowCount) pairs. 
+ */ + private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt]) + + /** + * Returns the cardinality of a base table access. A base table access represents + * a LeafNode, or Project or Filter operators above a LeafNode. + */ + private def getTableAccessCardinality( + input: LogicalPlan): Option[BigInt] = input match { + case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => + if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { + Option(input.stats(conf).rowCount.get) + } else { + Option(t.stats(conf).rowCount.get) + } + case _ => None + } + + /** + * Reorders a star join based on heuristics. It is called from ReorderJoin if CBO is disabled. + * 1) Finds the star join with the largest fact table. + * 2) Places the fact table the driving arm of the left-deep tree. + * This plan avoids large table access on the inner, and thus favor hash joins. + * 3) Applies the most selective dimensions early in the plan to reduce the amount of + * data flow. + */ + def reorderStarJoins( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = { + assert(input.size >= 2) + + val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)] + + // Find the eligible star plans. Currently, it only returns + // the star join with the largest fact table. + val eligibleJoins = input.collect{ case (plan, Inner) => plan } + val starPlan = findStarJoins(eligibleJoins, conditions) + + if (starPlan.isEmpty) { + emptyStarJoinPlan + } else { + val (factTable, dimTables) = (starPlan.head, starPlan.tail) + + // Only consider selective joins. This case is detected by observing local predicates + // on the dimension tables. In a star schema relationship, the join between the fact and the + // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce + // the result of a join. + if (isSelectiveStarJoin(dimTables, conditions)) { + val reorderDimTables = dimTables.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.sortBy(_.size).map { + case TableAccessCardinality(p1, _) => p1 + } + + val reorderStarPlan = factTable +: reorderDimTables + reorderStarPlan.map(plan => (plan, Inner)) + } else { + emptyStarJoinPlan + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 250dd07a16eb..c3ab58744953 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -20,338 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, PhysicalOperation} +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf -/** - * Encapsulates star-schema join detection. - */ -case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { - - /** - * Star schema consists of one or more fact tables referencing a number of dimension - * tables. In general, star-schema joins are detected using the following conditions: - * 1. 
Informational RI constraints (reliable detection) - * + Dimension contains a primary key that is being joined to the fact table. - * + Fact table contains foreign keys referencing multiple dimension tables. - * 2. Cardinality based heuristics - * + Usually, the table with the highest cardinality is the fact table. - * + Table being joined with the most number of tables is the fact table. - * - * To detect star joins, the algorithm uses a combination of the above two conditions. - * The fact table is chosen based on the cardinality heuristics, and the dimension - * tables are chosen based on the RI constraints. A star join will consist of the largest - * fact table joined with the dimension tables on their primary keys. To detect that a - * column is a primary key, the algorithm uses table and column statistics. - * - * Since Catalyst only supports left-deep tree plans, the algorithm currently returns only - * the star join with the largest fact table. Choosing the largest fact table on the - * driving arm to avoid large inners is in general a good heuristic. This restriction can - * be lifted with support for bushy tree plans. - * - * The highlights of the algorithm are the following: - * - * Given a set of joined tables/plans, the algorithm first verifies if they are eligible - * for star join detection. An eligible plan is a base table access with valid statistics. - * A base table access represents Project or Filter operators above a LeafNode. Conservatively, - * the algorithm only considers base table access as part of a star join since they provide - * reliable statistics. - * - * If some of the plans are not base table access, or statistics are not available, the algorithm - * returns an empty star join plan since, in the absence of statistics, it cannot make - * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality - * (number of rows), which is assumed to be a fact table. - * - * Next, it computes the set of dimension tables for the current fact table. A dimension table - * is assumed to be in a RI relationship with a fact table. To infer column uniqueness, - * the algorithm compares the number of distinct values with the total number of rows in the - * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted - * based on 1TB TPC-DS data), the column is assumed to be unique. - */ - def findStarJoins( - input: Seq[LogicalPlan], - conditions: Seq[Expression]): Seq[Seq[LogicalPlan]] = { - - val emptyStarJoinPlan = Seq.empty[Seq[LogicalPlan]] - - if (!conf.starSchemaDetection || input.size < 2) { - emptyStarJoinPlan - } else { - // Find if the input plans are eligible for star join detection. - // An eligible plan is a base table access with valid statistics. - val foundEligibleJoin = input.forall { - case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true - case _ => false - } - - if (!foundEligibleJoin) { - // Some plans don't have stats or are complex plans. Conservatively, - // return an empty star join. This restriction can be lifted - // once statistics are propagated in the plan. - emptyStarJoinPlan - } else { - // Find the fact table using cardinality based heuristics i.e. - // the table with the largest number of rows. 
- val sortedFactTables = input.map { plan => - TableAccessCardinality(plan, getTableAccessCardinality(plan)) - }.collect { case t @ TableAccessCardinality(_, Some(_)) => - t - }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse) - - sortedFactTables match { - case Nil => - emptyStarJoinPlan - case table1 :: table2 :: _ - if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble => - // If the top largest tables have comparable number of rows, return an empty star plan. - // This restriction will be lifted when the algorithm is generalized - // to return multiple star plans. - emptyStarJoinPlan - case TableAccessCardinality(factTable, _) :: rest => - // Find the fact table joins. - val allFactJoins = rest.collect { case TableAccessCardinality(plan, _) - if findJoinConditions(factTable, plan, conditions).nonEmpty => - plan - } - - // Find the corresponding join conditions. - val allFactJoinCond = allFactJoins.flatMap { plan => - val joinCond = findJoinConditions(factTable, plan, conditions) - joinCond - } - - // Verify if the join columns have valid statistics. - // Allow any relational comparison between the tables. Later - // we will heuristically choose a subset of equi-join - // tables. - val areStatsAvailable = allFactJoins.forall { dimTable => - allFactJoinCond.exists { - case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) => - val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs - val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs - hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable) - case _ => false - } - } - - if (!areStatsAvailable) { - emptyStarJoinPlan - } else { - // Find the subset of dimension tables. A dimension table is assumed to be in a - // RI relationship with the fact table. Only consider equi-joins - // between a fact and a dimension table to avoid expanding joins. - val eligibleDimPlans = allFactJoins.filter { dimTable => - allFactJoinCond.exists { - case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) => - val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs - isUnique(dimCol, dimTable) - case _ => false - } - } - - if (eligibleDimPlans.isEmpty) { - // An eligible star join was not found because the join is not - // an RI join, or the star join is an expanding join. - emptyStarJoinPlan - } else { - Seq(factTable +: eligibleDimPlans) - } - } - } - } - } - } - - /** - * Reorders a star join based on heuristics: - * 1) Finds the star join with the largest fact table and places it on the driving - * arm of the left-deep tree. This plan avoids large table access on the inner, and - * thus favor hash joins. - * 2) Applies the most selective dimensions early in the plan to reduce the amount of - * data flow. - */ - def reorderStarJoins( - input: Seq[(LogicalPlan, InnerLike)], - conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = { - assert(input.size >= 2) - - val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)] - - // Find the eligible star plans. Currently, it only returns - // the star join with the largest fact table. - val eligibleJoins = input.collect{ case (plan, Inner) => plan } - val starPlans = findStarJoins(eligibleJoins, conditions) - - if (starPlans.isEmpty) { - emptyStarJoinPlan - } else { - val starPlan = starPlans.head - val (factTable, dimTables) = (starPlan.head, starPlan.tail) - - // Only consider selective joins. This case is detected by observing local predicates - // on the dimension tables. 
In a star schema relationship, the join between the fact and the - // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce - // the result of a join. - // Also, conservatively assume that a fact table is joined with more than one dimension. - if (dimTables.size >= 2 && isSelectiveStarJoin(dimTables, conditions)) { - val reorderDimTables = dimTables.map { plan => - TableAccessCardinality(plan, getTableAccessCardinality(plan)) - }.sortBy(_.size).map { - case TableAccessCardinality(p1, _) => p1 - } - - val reorderStarPlan = factTable +: reorderDimTables - reorderStarPlan.map(plan => (plan, Inner)) - } else { - emptyStarJoinPlan - } - } - } - - /** - * Determines if a column referenced by a base table access is a primary key. - * A column is a PK if it is not nullable and has unique values. - * To determine if a column has unique values in the absence of informational - * RI constraints, the number of distinct values is compared to the total - * number of rows in the table. If their relative difference - * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based - * on TPCDS data results), the column is assumed to have unique values. - */ - private def isUnique( - column: Attribute, - plan: LogicalPlan): Boolean = plan match { - case PhysicalOperation(_, _, t: LeafNode) => - val leafCol = findLeafNodeCol(column, plan) - leafCol match { - case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) - stats.rowCount match { - case Some(rowCount) if rowCount >= 0 => - if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { - val colStats = stats.attributeStats.get(col) - if (colStats.get.nullCount > 0) { - false - } else { - val distinctCount = colStats.get.distinctCount - val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) - // ndvMaxErr adjusted based on TPCDS 1TB data results - relDiff <= conf.ndvMaxError * 2 - } - } else { - false - } - case None => false - } - case None => false - } - case _ => false - } - - /** - * Given a column over a base table access, it returns - * the leaf node column from which the input column is derived. - */ - @tailrec - private def findLeafNodeCol( - column: Attribute, - plan: LogicalPlan): Option[Attribute] = plan match { - case pl @ PhysicalOperation(_, _, _: LeafNode) => - pl match { - case t: LeafNode if t.outputSet.contains(column) => - Option(column) - case p: Project if p.outputSet.exists(_.semanticEquals(column)) => - val col = p.outputSet.find(_.semanticEquals(column)).get - findLeafNodeCol(col, p.child) - case f: Filter => - findLeafNodeCol(column, f.child) - case _ => None - } - case _ => None - } - - /** - * Checks if a column has statistics. - * The column is assumed to be over a base table access. - */ - private def hasStatistics( - column: Attribute, - plan: LogicalPlan): Boolean = plan match { - case PhysicalOperation(_, _, t: LeafNode) => - val leafCol = findLeafNodeCol(column, plan) - leafCol match { - case Some(col) if t.outputSet.contains(col) => - val stats = t.stats(conf) - stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) - case None => false - } - case _ => false - } - - /** - * Returns the join predicates between two input plans. It only - * considers basic comparison operators. 
- */ - @inline - private def findJoinConditions( - plan1: LogicalPlan, - plan2: LogicalPlan, - conditions: Seq[Expression]): Seq[Expression] = { - val refs = plan1.outputSet ++ plan2.outputSet - conditions.filter { - case BinaryComparison(_, _) => true - case _ => false - }.filterNot(canEvaluate(_, plan1)) - .filterNot(canEvaluate(_, plan2)) - .filter(_.references.subsetOf(refs)) - } - - /** - * Checks if a star join is a selective join. A star join is assumed - * to be selective if there are local predicates on the dimension - * tables. - */ - private def isSelectiveStarJoin( - dimTables: Seq[LogicalPlan], - conditions: Seq[Expression]): Boolean = dimTables.exists { - case plan @ PhysicalOperation(_, p, _: LeafNode) => - // Checks if any condition applies to the dimension tables. - // Exclude the IsNotNull predicates until predicate selectivity is available. - // In most cases, this predicate is artificially introduced by the Optimizer - // to enforce nullability constraints. - val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull]) - .exists(canEvaluate(_, plan)) - - // Checks if there are any predicates pushed down to the base table access. - val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull]) - - localPredicates || pushedDownPredicates - case _ => false - } - - /** - * Helper case class to hold (plan, rowCount) pairs. - */ - private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt]) - - /** - * Returns the cardinality of a base table access. A base table access represents - * a LeafNode, or Project or Filter operators above a LeafNode. - */ - private def getTableAccessCardinality( - input: LogicalPlan): Option[BigInt] = input match { - case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => - if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { - Option(input.stats(conf).rowCount.get) - } else { - Option(t.stats(conf).rowCount.get) - } - case _ => None - } -} - /** * Reorder the joins and push all the conditions into join, so that the bottom ones have at least * one condition. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 003ce49eaf8e..605c01b7220d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -206,7 +206,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // and d3_fk1 = s3_pk1 // // Default join reordering: d1, f1, d2, d3, s3 - // Star join reordering: f1, d1, d3, d2,, d3 + // Star join reordering: f1, d1, d3, d2, s3 val query = d1.join(f1).join(d2).join(s3).join(d3) @@ -242,7 +242,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // and d3_fk1 = s3_pk1 // // Default join reordering: d1, f1, d2, d3, s3 - // Star join reordering: f1, d1, d3, d2, d3 + // Star join reordering: f1, d1, d3, d2, s3 val query = d1.join(f1).join(d2).join(s3).join(d3) .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && From 5142e5d4e09c7cb36cf1d792934a21c5305c6d42 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Apr 2017 19:37:21 -0700 Subject: [PATCH 222/512] [SPARK-20217][CORE] Executor should not fail stage if killed task throws non-interrupted exception ## What changes were proposed in this pull request? 
If tasks throw non-interrupted exceptions on kill (e.g. java.nio.channels.ClosedByInterruptException), their death is reported back as TaskFailed instead of TaskKilled. This causes stage failure in some cases. This is reproducible as follows. Run the following, and then use SparkContext.killTaskAttempt to kill one of the tasks. The entire stage will fail since we threw a RuntimeException instead of InterruptedException.

```
spark.range(100).repartition(100).foreach { i =>
  try {
    Thread.sleep(10000000)
  } catch {
    case t: InterruptedException =>
      throw new RuntimeException(t)
  }
}
```

Based on the code in TaskSetManager, I think this also affects kills of speculative tasks. However, since the number of speculated tasks is few, and usually you need to fail a task a few times before the stage is cancelled, it is unlikely this would be noticed in production unless both speculation was enabled and the number of allowed task failures was 1.

We should probably unconditionally return TaskKilled instead of TaskFailed if the task was killed by the driver, regardless of the actual exception thrown.

## How was this patch tested?

Unit test. The test fails before the change in Executor.scala. cc JoshRosen

Author: Eric Liang

Closes #17531 from ericl/fix-task-interrupt.
---
 .../main/scala/org/apache/spark/executor/Executor.scala | 2 +-
 .../test/scala/org/apache/spark/SparkContextSuite.scala | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 99b1608010dd..83469c5ff060 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -432,7 +432,7 @@ private[spark] class Executor(
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
 
-        case _: InterruptedException if task.reasonIfKilled.isDefined =>
+        case NonFatal(_) if task != null && task.reasonIfKilled.isDefined =>
           val killReason = task.reasonIfKilled.getOrElse("unknown reason")
           logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
           setTaskFinishedAndClearInterruptStatus()
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 2c947556dfd3..735f4454e299 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -572,7 +572,13 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
       // first attempt will hang
       if (!SparkContextSuite.isTaskStarted) {
         SparkContextSuite.isTaskStarted = true
-        Thread.sleep(9999999)
+        try {
+          Thread.sleep(9999999)
+        } catch {
+          case t: Throwable =>
+            // SPARK-20217 should not fail stage if task throws non-interrupted exception
+            throw new RuntimeException("killed")
+        }
       }
       // second attempt succeeds immediately
     }

From e156b5dd39dc1992077fe06e0f8be810c49c8255 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Thu, 6 Apr 2017 09:41:32 +0200
Subject: [PATCH 223/512] [SPARK-19953][ML] Random Forest Models use parent UID when being fit

## What changes were proposed in this pull request?

The ML `RandomForestClassificationModel` and `RandomForestRegressionModel` were not using the estimator parent UID when being fit. This change fixes that so the models can properly be identified with their parents.
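For illustration, a minimal sketch of the intended behavior (it assumes a `training` DataFrame with the usual `label`/`features` columns already exists; the snippet is not part of the patch itself):

```scala
// Sketch only: `training` is assumed to be an existing DataFrame with
// "label" and "features" columns suitable for classification.
import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier().setNumTrees(3)
val model = rf.fit(training)

// With this change, the fitted model reuses its parent estimator's UID,
// so the model can be traced back to the estimator that produced it.
assert(model.uid == rf.uid)
assert(model.parent.uid == rf.uid)
```

The renamed `checkCopyAndUids` test helper asserts exactly this UID relationship, in addition to the existing check that `copy` preserves the model's parent.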
## How was this patch tested?Existing tests. Added check to verify that model uid matches that of the parent, then renamed `checkCopy` to `checkCopyAndUids` and verified that it was called by one test for each ML algorithm. Author: Bryan Cutler Closes #17296 from BryanCutler/rfmodels-use-parent-uid-SPARK-19953. --- .../RandomForestClassifier.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 3 +- .../classification/GBTClassifierSuite.scala | 6 ++-- .../ml/classification/LinearSVCSuite.scala | 3 +- .../LogisticRegressionSuite.scala | 3 +- .../MultilayerPerceptronClassifierSuite.scala | 2 +- .../ml/classification/NaiveBayesSuite.scala | 1 + .../ml/classification/OneVsRestSuite.scala | 3 +- .../RandomForestClassifierSuite.scala | 3 +- .../ml/clustering/BisectingKMeansSuite.scala | 3 +- .../ml/clustering/GaussianMixtureSuite.scala | 3 +- .../spark/ml/clustering/KMeansSuite.scala | 3 +- .../apache/spark/ml/clustering/LDASuite.scala | 4 +-- .../BucketedRandomProjectionLSHSuite.scala | 3 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 9 +++-- .../ml/feature/CountVectorizerSuite.scala | 9 ++--- .../apache/spark/ml/feature/IDFSuite.scala | 8 +++-- .../org/apache/spark/ml/feature/LSHTest.scala | 4 ++- .../spark/ml/feature/MaxAbsScalerSuite.scala | 3 +- .../spark/ml/feature/MinHashLSHSuite.scala | 2 +- .../spark/ml/feature/MinMaxScalerSuite.scala | 3 +- .../apache/spark/ml/feature/PCASuite.scala | 8 ++--- .../spark/ml/feature/RFormulaSuite.scala | 2 +- .../ml/feature/StandardScalerSuite.scala | 7 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 7 ++-- .../spark/ml/feature/VectorIndexerSuite.scala | 3 +- .../spark/ml/feature/Word2VecSuite.scala | 7 ++-- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 7 ++-- .../spark/ml/recommendation/ALSSuite.scala | 3 +- .../AFTSurvivalRegressionSuite.scala | 3 +- .../DecisionTreeRegressorSuite.scala | 7 ++-- .../ml/regression/GBTRegressorSuite.scala | 3 +- .../GeneralizedLinearRegressionSuite.scala | 3 +- .../regression/IsotonicRegressionSuite.scala | 3 +- .../ml/regression/LinearRegressionSuite.scala | 3 +- .../RandomForestRegressorSuite.scala | 2 ++ .../spark/ml/tuning/CrossValidatorSuite.scala | 3 +- .../ml/tuning/TrainValidationSplitSuite.scala | 35 +++++++++---------- .../apache/spark/ml/util/MLTestingUtils.scala | 8 +++-- 41 files changed, 98 insertions(+), 100 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ce834f1d17e0..ab4c23520928 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -140,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") ( .map(_.asInstanceOf[DecisionTreeClassificationModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestClassificationModel(trees, numFeatures, numClasses) + val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) instr.logSuccess(m) m } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2f524a8c5784..a58da50fad97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -131,7 +131,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .map(_.asInstanceOf[DecisionTreeRegressionModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestRegressionModel(trees, numFeatures) + val m = new RandomForestRegressionModel(uid, trees, numFeatures) instr.logSuccess(m) m } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index dafc6c200f95..4cdbf845ae4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -79,7 +79,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) - MLTestingUtils.checkCopy(pipelineModel) + MLTestingUtils.checkCopyAndUids(pipeline, pipelineModel) assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 964fcfbdd87a..918ab27e2730 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -249,8 +249,7 @@ class DecisionTreeClassifierSuite val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // copied model must have the same parent. - MLTestingUtils.checkCopy(newTree) + MLTestingUtils.checkCopyAndUids(dt, newTree) val predictions = newTree.transform(newData) .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 0cddb37281b3..1f79e0d4e622 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -97,8 +97,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.getProbabilityCol === "probability") assert(model.hasParent) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) } test("setThreshold, getThreshold") { @@ -261,8 +260,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext .setSeed(123) val model = gbt.fit(df) - // copied model must have the same parent. 
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) sc.checkpointDir = None Utils.deleteRecursively(tempDir) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index c763a4cef1af..2f87afc23fe7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -124,8 +124,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(model.hasParent) assert(model.numFeatures === 2) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lsvc, model) } test("linear svc doesn't fit intercept when fitIntercept is off") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index f0648d0936a1..c858b9bbfc25 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -142,8 +142,7 @@ class LogisticRegressionSuite assert(model.intercept !== 0.0) assert(model.hasParent) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lr, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 7700099caac3..ce54c3df4f3f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -74,8 +74,8 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) - MLTestingUtils.checkCopy(model) val result = model.transform(dataset) + MLTestingUtils.checkCopyAndUids(trainer, model) val predictionAndLabels = result.select("prediction", "label").collect() predictionAndLabels.foreach { case Row(p: Double, l: Double) => assert(p == l) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index d41c5b533ded..b56f8e19ca53 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -149,6 +149,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateModelFit(pi, theta, model) assert(model.hasParent) + MLTestingUtils.checkCopyAndUids(nb, model) val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index aacb7921b835..c02e38ad64e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -76,8 +76,7 @@ 
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) - // copied model must have the same parent. - MLTestingUtils.checkCopy(ovaModel) + MLTestingUtils.checkCopyAndUids(ova, ovaModel) assert(ovaModel.models.length === numClasses) val transformedDataset = ovaModel.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index c3003cec73b4..ca2954d2f32c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -141,8 +141,7 @@ class RandomForestClassifierSuite val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(rf, model) val predictions = model.transform(df) .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 200a892f6c69..fa7471fa2d65 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -47,8 +47,7 @@ class BisectingKMeansSuite assert(bkm.getMinDivisibleClusterSize === 1.0) val model = bkm.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(bkm, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 61da897b666f..08b800b7e418 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -77,8 +77,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(gm.getTol === 0.01) val model = gm.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gm, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ca05b9c389f6..119fe1dead9a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -52,8 +52,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getTol === 1e-4) val model = kmeans.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(kmeans, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 75aa0be61a3e..b4fe63a89f87 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -176,7 +176,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2) val model = lda.fit(dataset) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lda, model) assert(model.isInstanceOf[LocalLDAModel]) assert(model.vocabSize === vocabSize) @@ -221,7 +221,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2) val model_ = lda.fit(dataset) - MLTestingUtils.checkCopy(model_) + MLTestingUtils.checkCopyAndUids(lda, model_) assert(model_.isInstanceOf[DistributedLDAModel]) val model = model_.asInstanceOf[DistributedLDAModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index cc81da5c66e6..7175c721bff3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -94,7 +94,8 @@ class BucketedRandomProjectionLSHSuite unitVectors.foreach { v: Vector => assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) } - MLTestingUtils.checkCopy(brpModel) + + MLTestingUtils.checkCopyAndUids(brp, brpModel) } test("BucketedRandomProjectionLSH: test of LSH property") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index d6925da97d57..c83909c4498f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -119,7 +119,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext test("Test Chi-Square selector: numTopFeatures") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) - ChiSqSelectorSuite.testSelector(selector, dataset) + val model = ChiSqSelectorSuite.testSelector(selector, dataset) + MLTestingUtils.checkCopyAndUids(selector, model) } test("Test Chi-Square selector: percentile") { @@ -166,11 +167,13 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext object ChiSqSelectorSuite { - private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = { - selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect() + private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = { + val selectorModel = selector.fit(dataset) + selectorModel.transform(dataset).select("filtered", "topFeature").collect() .foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) } + selectorModel } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 69d3033bb218..f213145f1ba0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala 
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row @@ -68,10 +68,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") - .fit(df) - assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) + val cvm = cv.fit(df) + MLTestingUtils.checkCopyAndUids(cv, cvm) + assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e")) - cv.transform(df).select("features", "expected").collect().foreach { + cvm.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 5325d95526a5..005edf73d29b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -65,10 +65,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val df = data.zip(expected).toSeq.toDF("features", "expected") - val idfModel = new IDF() + val idfEst = new IDF() .setInputCol("features") .setOutputCol("idfValue") - .fit(df) + val idfModel = idfEst.fit(df) + + MLTestingUtils.checkCopyAndUids(idfEst, idfModel) idfModel.transform(df).select("idfValue", "expected").collect().foreach { case Row(x: Vector, y: Vector) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala index a9b559f7ba64..dd4dd62b8cfe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.linalg.{Vector, VectorUDT} -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils} import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DataTypes @@ -58,6 +58,8 @@ private[ml] object LSHTest { val outputCol = model.getOutputCol val transformedData = model.transform(dataset) + MLTestingUtils.checkCopyAndUids(lsh, model) + // Check output column type SchemaUtils.checkColumnType( transformedData.schema, model.getOutputCol, DataTypes.createArrayType(new VectorUDT)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index a12174493b86..918da4f9388d 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -50,8 +50,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") } - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(scaler, model) } test("MaxAbsScaler read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 0ddf097a6eb2..96df68dbdf05 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -63,7 +63,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("values") val model = mh.fit(dataset) assert(mh.uid === model.uid) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(mh, model) } test("hashFunction") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index b79eeb2d75ef..51db74eb739c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -53,8 +53,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De assert(vector1.equals(vector2), "Transformed vector is different with expected.") } - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(scaler, model) } test("MinMaxScaler arguments max must be larger than min") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index a60e87590f06..3067a52a4df7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -58,12 +58,12 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setInputCol("features") .setOutputCol("pca_features") .setK(3) - .fit(df) - // copied model must have the same parent. 
- MLTestingUtils.checkCopy(pca) + val pcaModel = pca.fit(df) - pca.transform(df).select("pca_features", "expected").collect().foreach { + MLTestingUtils.checkCopyAndUids(pca, pcaModel) + + pcaModel.transform(df).select("pca_features", "expected").collect().foreach { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5cfd59e6b88a..fbebd75d70ac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -37,7 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val formula = new RFormula().setFormula("id ~ v1 + v2") val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(formula, model) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) val expected = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index a928f9363301..350ba44baa1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -77,10 +77,11 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext test("Standardization with default parameter") { val df0 = data.zip(resWithStd).toSeq.toDF("features", "expected") - val standardScaler0 = new StandardScaler() + val standardScalerEst0 = new StandardScaler() .setInputCol("features") .setOutputCol("standardized_features") - .fit(df0) + val standardScaler0 = standardScalerEst0.fit(df0) + MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0) assertResult(standardScaler0.transform(df0)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 8d9042b31e03..5634d4210f47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -45,12 +45,11 @@ class StringIndexerSuite val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") - .fit(df) + val indexerModel = indexer.fit(df) - // copied model must have the same parent. 
- MLTestingUtils.checkCopy(indexer) + MLTestingUtils.checkCopyAndUids(indexer, indexerModel) - val transformed = indexer.transform(df) + val transformed = indexerModel.transform(df) val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("a", "c", "b")) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index b28ce2ab45b4..f2cca8aa82e8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -114,8 +114,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val vectorIndexer = getIndexer val model = vectorIndexer.fit(densePoints1) // vectors of length 3 - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(vectorIndexer, model) model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 2043a16c15f1..a6a1c2b4f32b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -57,15 +57,14 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val docDF = doc.zip(expected).toDF("text", "expected") - val model = new Word2Vec() + val w2v = new Word2Vec() .setVectorSize(3) .setInputCol("text") .setOutputCol("result") .setSeed(42L) - .fit(docDF) + val model = w2v.fit(docDF) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(w2v, model) // These expectations are just magic values, characterizing the current // behavior. 
The test needs to be updated to be more general, see SPARK-11502 diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 6bec057511cd..6806cb03bc42 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.fpm import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -121,7 +122,9 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(fpGrowth, model) + ParamsSuite.checkParams(fpGrowth) + ParamsSuite.checkParams(model) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index a177ed13bf8e..7574af3d77ea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -409,8 +409,7 @@ class ALSSuite logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(als, model) } test("exact rank-1 matrix") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 708185a0943d..fb39e50a8355 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -83,8 +83,7 @@ class AFTSurvivalRegressionSuite .setQuantilesCol("quantiles") .fit(datasetUnivariate) - // copied model must have the same parent. 
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(aftr, model) model.transform(datasetUnivariate) .select("label", "prediction", "quantiles") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 0e91284d03d9..642f266891b5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -69,11 +69,12 @@ class DecisionTreeRegressorSuite test("copied model must have the same parent") { val categoricalFeatures = Map(0 -> 2, 1 -> 2) val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) - val model = new DecisionTreeRegressor() + val dtr = new DecisionTreeRegressor() .setImpurity("variance") .setMaxDepth(2) - .setMaxBins(8).fit(df) - MLTestingUtils.checkCopy(model) + .setMaxBins(8) + val model = dtr.fit(df) + MLTestingUtils.checkCopyAndUids(dtr, model) } test("predictVariance") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 03c2f97797bc..2da25f7e0100 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -90,8 +90,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext .setMaxIter(2) val model = gbt.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) val preds = model.transform(df) val predictions = preds.select("prediction").rdd.map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 401911763fa3..f7c7c001a36a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -197,8 +197,7 @@ class GeneralizedLinearRegressionSuite val model = glr.setFamily("gaussian").setLink("identity") .fit(datasetGaussianIdentity) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(glr, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index f41a3601b1fa..180f5f7ce5ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -93,8 +93,7 @@ class IsotonicRegressionSuite val model = ir.fit(dataset) - // copied model must have the same parent. 
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(ir, model) model.transform(dataset) .select("label", "features", "prediction", "weight") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index c6a267b7283d..e7bd4eb9e0ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -148,8 +148,7 @@ class LinearRegressionSuite assert(lir.getSolver == "auto") val model = lir.fit(datasetWithDenseFeature) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lir, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 3bf0445ebd3d..8b8e8a655f47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -90,6 +90,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val model = rf.fit(df) + MLTestingUtils.checkCopyAndUids(rf, model) + val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 7116265474f2..2b4e6b53e4f8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -58,8 +58,7 @@ class CrossValidatorSuite .setNumFolds(3) val cvModel = cv.fit(dataset) - // copied model must have the same paren. 
- MLTestingUtils.checkCopy(cvModel) + MLTestingUtils.checkCopyAndUids(cv, cvModel) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 4463a9b6e543..a34f930aa11c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -45,18 +45,18 @@ class TrainValidationSplitSuite .addGrid(lr.maxIter, Array(0, 10)) .build() val eval = new BinaryClassificationEvaluator - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(lr) .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) .setSeed(42L) - val cvModel = cv.fit(dataset) - val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] - assert(cv.getTrainRatio === 0.5) + val tvsModel = tvs.fit(dataset) + val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression] + assert(tvs.getTrainRatio === 0.5) assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.validationMetrics.length === lrParamMaps.length) + assert(tvsModel.validationMetrics.length === lrParamMaps.length) } test("train validation with linear regression") { @@ -71,28 +71,27 @@ class TrainValidationSplitSuite .addGrid(trainer.maxIter, Array(0, 10)) .build() val eval = new RegressionEvaluator() - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(trainer) .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) .setSeed(42L) - val cvModel = cv.fit(dataset) + val tvsModel = tvs.fit(dataset) - // copied model must have the same paren. - MLTestingUtils.checkCopy(cvModel) + MLTestingUtils.checkCopyAndUids(tvs, tvsModel) - val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + val parent = tvsModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.validationMetrics.length === lrParamMaps.length) + assert(tvsModel.validationMetrics.length === lrParamMaps.length) eval.setMetricName("r2") - val cvModel2 = cv.fit(dataset) - val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + val tvsModel2 = tvs.fit(dataset) + val parent2 = tvsModel2.bestModel.parent.asInstanceOf[LinearRegression] assert(parent2.getRegParam === 0.001) assert(parent2.getMaxIter === 10) - assert(cvModel2.validationMetrics.length === lrParamMaps.length) + assert(tvsModel2.validationMetrics.length === lrParamMaps.length) } test("transformSchema should check estimatorParamMaps") { @@ -104,17 +103,17 @@ class TrainValidationSplitSuite .addGrid(est.inputCol, Array("input1", "input2")) .build() - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(est) .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) - cv.transformSchema(new StructType()) // This should pass. + tvs.transformSchema(new StructType()) // This should pass. 
val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") - cv.setEstimatorParamMaps(invalidParamMaps) + tvs.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.transformSchema(new StructType()) + tvs.transformSchema(new StructType()) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 578f31c8e7db..bef79e634f75 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -31,11 +31,15 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ object MLTestingUtils extends SparkFunSuite { - def checkCopy(model: Model[_]): Unit = { + + def checkCopyAndUids[T <: Estimator[_]](estimator: T, model: Model[_]): Unit = { + assert(estimator.uid === model.uid, "Model uid does not match parent estimator") + + // copied model must have the same parent val copied = model.copy(ParamMap.empty) .asInstanceOf[Model[_]] - assert(copied.parent.uid == model.parent.uid) assert(copied.parent == model.parent) + assert(copied.parent.uid == model.parent.uid) } def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( From c8fc1f3badf61bcfc4bd8eeeb61f73078ca068d1 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 6 Apr 2017 09:14:31 +0100 Subject: [PATCH 224/512] [SPARK-20085][MESOS] Configurable mesos labels for executors ## What changes were proposed in this pull request? Add spark.mesos.task.labels configuration option to add mesos key:value labels to the executor. "k1:v1,k2:v2" as the format, colons separating key-value and commas to list out more than one. Discussion of labels with mgummelt at #17404 ## How was this patch tested? Added unit tests to verify labels were added correctly, with incorrect labels being ignored and added a test to test the name of the executor. Tested with: `./build/sbt -Pmesos mesos/test` Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Kalvin Chau Closes #17413 from kalvinnchau/mesos-labels. --- docs/running-on-mesos.md | 9 ++++ .../MesosCoarseGrainedSchedulerBackend.scala | 24 ++++++++++ ...osCoarseGrainedSchedulerBackendSuite.scala | 46 +++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8d5ad12cb85b..ef01cfe4b92c 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -367,6 +367,15 @@ See the [configuration page](configuration.html) for information on Spark config

    [host_path:]container_path[:ro|:rw]
    + + spark.mesos.task.labels + (none) + + Set the Mesos labels to add to each task. Labels are free-form key-value pairs. + Key-value pairs should be separated by a colon, and commas used to list more than one. + Ex. key:value,key2:value2. + + spark.mesos.executor.home driver side SPARK_HOME diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 5bdc2a2b840e..2a36ec4fa811 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -67,6 +67,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + private val taskLabels = conf.get("spark.mesos.task.labels", "") + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") @@ -408,6 +410,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( taskBuilder.addAllResources(resourcesToUse.asJava) taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + val labelsBuilder = taskBuilder.getLabelsBuilder + val labels = buildMesosLabels().asJava + + labelsBuilder.addAllLabels(labels) + + taskBuilder.setLabels(labelsBuilder) + tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava totalCoresAcquired += taskCPUs @@ -422,6 +431,21 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( tasks.toMap } + private def buildMesosLabels(): List[Label] = { + taskLabels.split(",").flatMap(label => + label.split(":") match { + case Array(key, value) => + Some(Label.newBuilder() + .setKey(key) + .setValue(value) + .build()) + case _ => + logWarning(s"Unable to parse $label into a key:value label for the task.") + None + } + ).toList + } + /** Extracts task needed resources from a list of available resources. 
*/ private def partitionTaskResources( resources: JList[Resource], diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index eb83926ae410..c040f05d93b3 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -475,6 +475,52 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(launchedTasks.head.getName == "test-mesos-dynamic-alloc 0") } + test("mesos sets configurable labels on tasks") { + val taskLabelsString = "mesos:test,label:test" + setBackend(Map( + "spark.mesos.task.labels" -> taskLabelsString + )) + + // Build up the labels + val taskLabels = Protos.Labels.newBuilder() + .addLabels(Protos.Label.newBuilder() + .setKey("mesos").setValue("test").build()) + .addLabels(Protos.Label.newBuilder() + .setKey("label").setValue("test").build()) + .build() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + val labels = launchedTasks.head.getLabels + + assert(launchedTasks.head.getLabels.equals(taskLabels)) + } + + test("mesos ignored invalid labels and sets configurable labels on tasks") { + val taskLabelsString = "mesos:test,label:test,incorrect:label:here" + setBackend(Map( + "spark.mesos.task.labels" -> taskLabelsString + )) + + // Build up the labels + val taskLabels = Protos.Labels.newBuilder() + .addLabels(Protos.Label.newBuilder() + .setKey("mesos").setValue("test").build()) + .addLabels(Protos.Label.newBuilder() + .setKey("label").setValue("test").build()) + .build() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + val labels = launchedTasks.head.getLabels + + assert(launchedTasks.head.getLabels.equals(taskLabels)) + } + test("mesos supports spark.mesos.network.name") { setBackend(Map( "spark.mesos.network.name" -> "test-network-name" From d009fb369bbea0df81bbcf9c8028d14cfcaa683b Mon Sep 17 00:00:00 2001 From: setjet Date: Thu, 6 Apr 2017 09:43:07 +0100 Subject: [PATCH 225/512] [SPARK-20064][PYSPARK] Bump the PySpark verison number to 2.2 ## What changes were proposed in this pull request? PySpark version in version.py was lagging behind Versioning is in line with PEP 440: https://www.python.org/dev/peps/pep-0440/ ## How was this patch tested? Simply rebuild the project with existing tests Author: setjet Author: Ruben Janssen Closes #17523 from setjet/SPARK-20064. --- python/pyspark/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 08a301695fda..41bf8c269b79 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2.1.0.dev0" +__version__ = "2.2.0.dev0" From bccc330193217b2ec9660e06f1db6dd58f7af5d8 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 6 Apr 2017 09:09:43 -0700 Subject: [PATCH 226/512] [SPARK-20196][PYTHON][SQL] update doc for catalog functions for all languages, add pyspark refreshByPath API ## What changes were proposed in this pull request? Update doc to remove external for createTable, add refreshByPath in python ## How was this patch tested? manual Author: Felix Cheung Closes #17512 from felixcheung/catalogdoc. --- R/pkg/R/SQLContext.R | 11 ++-- R/pkg/R/catalog.R | 52 +++++++++++-------- python/pyspark/sql/catalog.py | 27 +++++++--- python/pyspark/sql/context.py | 2 +- .../apache/spark/sql/catalog/Catalog.scala | 17 +++--- .../spark/sql/internal/CatalogImpl.scala | 22 +++++--- 6 files changed, 79 insertions(+), 52 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index a1edef7608fa..c2a1e240ad39 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -544,12 +544,15 @@ sql <- function(x, ...) { dispatchFunc("sql(sqlQuery)", x, ...) } -#' Create a SparkDataFrame from a SparkSQL Table +#' Create a SparkDataFrame from a SparkSQL table or view #' -#' Returns the specified Table as a SparkDataFrame. The Table must have already been registered -#' in the SparkSession. +#' Returns the specified table or view as a SparkDataFrame. The table or view must already exist or +#' have already been registered in the SparkSession. #' -#' @param tableName The SparkSQL Table to convert to a SparkDataFrame. +#' @param tableName the qualified or unqualified name that designates a table or view. If a database +#' is specified, it identifies the table/view from the database. +#' Otherwise, it first attempts to find a temporary view with the given name +#' and then match the table/view from the current database. #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index 07a89f763cde..4b7f841b55dd 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -65,7 +65,8 @@ createExternalTable <- function(x, ...) { #' #' Caches the specified table in-memory. #' -#' @param tableName The name of the table being cached +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname cacheTable #' @export @@ -94,7 +95,8 @@ cacheTable <- function(x, ...) { #' #' Removes the specified table from the in-memory cache. #' -#' @param tableName The name of the table being uncached +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname uncacheTable #' @export @@ -162,6 +164,7 @@ clearCache <- function() { #' @method dropTempTable default #' @note dropTempTable since 1.4.0 dropTempTable.default <- function(tableName) { + .Deprecated("dropTempView", old = "dropTempTable") if (class(tableName) != "character") { stop("tableName must be a string.") } @@ -169,7 +172,6 @@ dropTempTable.default <- function(tableName) { } dropTempTable <- function(x, ...) { - .Deprecated("dropTempView") dispatchFunc("dropTempView(viewName)", x, ...) } @@ -178,7 +180,7 @@ dropTempTable <- function(x, ...) { #' Drops the temporary view with the given view name in the catalog. #' If the view has been cached before, then it will also be uncached. 
#' -#' @param viewName the name of the view to be dropped. +#' @param viewName the name of the temporary view to be dropped. #' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView @@ -317,10 +319,10 @@ listDatabases <- function() { dataFrame(callJMethod(callJMethod(catalog, "listDatabases"), "toDF")) } -#' Returns a list of tables in the specified database +#' Returns a list of tables or views in the specified database #' -#' Returns a list of tables in the specified database. -#' This includes all temporary tables. +#' Returns a list of tables or views in the specified database. +#' This includes all temporary views. #' #' @param databaseName (optional) name of the database #' @return a SparkDataFrame of the list of tables. @@ -349,11 +351,13 @@ listTables <- function(databaseName = NULL) { dataFrame(callJMethod(jdst, "toDF")) } -#' Returns a list of columns for the given table in the specified database +#' Returns a list of columns for the given table/view in the specified database #' -#' Returns a list of columns for the given table in the specified database. +#' Returns a list of columns for the given table/view in the specified database. #' -#' @param tableName a name of the table. +#' @param tableName the qualified or unqualified name that designates a table/view. If no database +#' identifier is provided, it refers to a table/view in the current database. +#' If \code{databaseName} parameter is specified, this must be an unqualified name. #' @param databaseName (optional) name of the database #' @return a SparkDataFrame of the list of column descriptions. #' @rdname listColumns @@ -409,12 +413,13 @@ listFunctions <- function(databaseName = NULL) { dataFrame(callJMethod(jdst, "toDF")) } -#' Recover all the partitions in the directory of a table and update the catalog +#' Recovers all the partitions in the directory of a table and update the catalog #' -#' Recover all the partitions in the directory of a table and update the catalog. The name should -#' reference a partitioned table, and not a temporary view. +#' Recovers all the partitions in the directory of a table and update the catalog. The name should +#' reference a partitioned table, and not a view. #' -#' @param tableName a name of the table. +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. #' @rdname recoverPartitions #' @name recoverPartitions #' @export @@ -430,17 +435,18 @@ recoverPartitions <- function(tableName) { invisible(handledCallJMethod(catalog, "recoverPartitions", tableName)) } -#' Invalidate and refresh all the cached metadata of the given table +#' Invalidates and refreshes all the cached data and metadata of the given table #' -#' Invalidate and refresh all the cached metadata of the given table. For performance reasons, -#' Spark SQL or the external data source library it uses might cache certain metadata about a -#' table, such as the location of blocks. When those change outside of Spark SQL, users should +#' Invalidates and refreshes all the cached data and metadata of the given table. For performance +#' reasons, Spark SQL or the external data source library it uses might cache certain metadata about +#' a table, such as the location of blocks. When those change outside of Spark SQL, users should #' call this function to invalidate the cache. 
#' #' If this table is cached as an InMemoryRelation, drop the original cached version and make the #' new version cached lazily. #' -#' @param tableName a name of the table. +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. #' @rdname refreshTable #' @name refreshTable #' @export @@ -456,11 +462,11 @@ refreshTable <- function(tableName) { invisible(handledCallJMethod(catalog, "refreshTable", tableName)) } -#' Invalidate and refresh all the cached data and metadata for SparkDataFrame containing path +#' Invalidates and refreshes all the cached data and metadata for SparkDataFrame containing path #' -#' Invalidate and refresh all the cached data (and the associated metadata) for any SparkDataFrame -#' that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate -#' everything that is cached. +#' Invalidates and refreshes all the cached data (and the associated metadata) for any +#' SparkDataFrame that contains the given data source path. Path matching is by prefix, i.e. "/" +#' would invalidate everything that is cached. #' #' @param path the path of the data source. #' @rdname refreshByPath diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 253a75062917..41e68a45a615 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -72,10 +72,10 @@ def listDatabases(self): @ignore_unicode_prefix @since(2.0) def listTables(self, dbName=None): - """Returns a list of tables in the specified database. + """Returns a list of tables/views in the specified database. If no database is specified, the current database is used. - This includes all temporary tables. + This includes all temporary views. """ if dbName is None: dbName = self.currentDatabase() @@ -115,7 +115,7 @@ def listFunctions(self, dbName=None): @ignore_unicode_prefix @since(2.0) def listColumns(self, tableName, dbName=None): - """Returns a list of columns for the given table in the specified database. + """Returns a list of columns for the given table/view in the specified database. If no database is specified, the current database is used. @@ -161,14 +161,15 @@ def createExternalTable(self, tableName, path=None, source=None, schema=None, ** def createTable(self, tableName, path=None, source=None, schema=None, **options): """Creates a table based on the dataset in a data source. - It returns the DataFrame associated with the external table. + It returns the DataFrame associated with the table. The data source is specified by the ``source`` and a set of ``options``. If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. + ``spark.sql.sources.default`` will be used. When ``path`` is specified, an external table is + created from the data at the given path. Otherwise a managed table is created. Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and - created external table. + created table. 
:return: :class:`DataFrame` """ @@ -276,14 +277,24 @@ def clearCache(self): @since(2.0) def refreshTable(self, tableName): - """Invalidate and refresh all the cached metadata of the given table.""" + """Invalidates and refreshes all the cached data and metadata of the given table.""" self._jcatalog.refreshTable(tableName) @since('2.1.1') def recoverPartitions(self, tableName): - """Recover all the partitions of the given table and update the catalog.""" + """Recovers all the partitions of the given table and update the catalog. + + Only works with a partitioned table, and not a view. + """ self._jcatalog.recoverPartitions(tableName) + @since('2.2.0') + def refreshByPath(self, path): + """Invalidates and refreshes all the cached data (and the associated metadata) for any + DataFrame that contains the given data source path. + """ + self._jcatalog.refreshByPath(path) + def _reset(self): """(Internal use only) Drop all existing databases (except "default"), tables, partitions and functions, and set the current database to "default". diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c22f4b87e1a7..fdb7abbad4e5 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -385,7 +385,7 @@ def sql(self, sqlQuery): @since(1.0) def table(self, tableName): - """Returns the specified table as a :class:`DataFrame`. + """Returns the specified table or view as a :class:`DataFrame`. :return: :class:`DataFrame` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 137b0cbc84f8..074952ff7900 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -283,7 +283,7 @@ abstract class Catalog { /** * :: Experimental :: - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * * @param tableName is either a qualified or unqualified name that designates a table. @@ -321,7 +321,7 @@ abstract class Catalog { /** * :: Experimental :: * (Scala-specific) - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * * @param tableName is either a qualified or unqualified name that designates a table. @@ -357,7 +357,7 @@ abstract class Catalog { /** * :: Experimental :: - * Create a table from the given path based on a data source, a schema and a set of options. + * Create a table based on the dataset in a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @param tableName is either a qualified or unqualified name that designates a table. @@ -397,7 +397,7 @@ abstract class Catalog { /** * :: Experimental :: * (Scala-specific) - * Create a table from the given path based on a data source, a schema and a set of options. + * Create a table based on the dataset in a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @param tableName is either a qualified or unqualified name that designates a table. @@ -447,6 +447,7 @@ abstract class Catalog { /** * Recovers all the partitions in the directory of a table and update the catalog. + * Only works with a partitioned table, and not a view. 
* * @param tableName is either a qualified or unqualified name that designates a table. * If no database identifier is provided, it refers to a table in the @@ -493,10 +494,10 @@ abstract class Catalog { def clearCache(): Unit /** - * Invalidates and refreshes all the cached metadata of the given table. For performance reasons, - * Spark SQL or the external data source library it uses might cache certain metadata about a - * table, such as the location of blocks. When those change outside of Spark SQL, users should - * call this function to invalidate the cache. + * Invalidates and refreshes all the cached data and metadata of the given table. For performance + * reasons, Spark SQL or the external data source library it uses might cache certain metadata + * about a table, such as the location of blocks. When those change outside of Spark SQL, users + * should call this function to invalidate the cache. * * If this table is cached as an InMemoryRelation, drop the original cached version and make the * new version cached lazily. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 5d1c35aba529..aebb663df5c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -141,7 +141,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns a list of columns for the given table temporary view. + * Returns a list of columns for the given table/view or temporary view. */ @throws[AnalysisException]("table does not exist") override def listColumns(tableName: String): Dataset[Column] = { @@ -150,7 +150,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns a list of columns for the given table in the specified database. + * Returns a list of columns for the given table/view or temporary view in the specified database. */ @throws[AnalysisException]("database or table does not exist") override def listColumns(dbName: String, tableName: String): Dataset[Column] = { @@ -273,7 +273,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: - * Creates a table from the given path based on a data source and returns the corresponding + * Creates a table from the given path and returns the corresponding * DataFrame. * * @group ddl_ops @@ -287,7 +287,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: * (Scala-specific) - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops @@ -304,7 +304,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: * (Scala-specific) - * Creates a table from the given path based on a data source, a schema and a set of options. + * Creates a table based on the dataset in a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops @@ -367,6 +367,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * Recovers all the partitions in the directory of a table and update the catalog. + * Only works with a partitioned table, and not a temporary view. * * @param tableName is either a qualified or unqualified name that designates a table. 
* If no database identifier is provided, it refers to a table in the @@ -431,8 +432,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Refreshes the cache entry for a table or view, if any. For Hive metastore table, the metadata - * is refreshed. For data source tables, the schema will not be inferred and refreshed. + * Invalidates and refreshes all the cached data and metadata of the given table or view. + * For Hive metastore table, the metadata is refreshed. For data source tables, the schema will + * not be inferred and refreshed. + * + * If this table is cached as an InMemoryRelation, drop the original cached version and make the + * new version cached lazily. * * @group cachemgmt * @since 2.0.0 @@ -456,7 +461,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * Refreshes the cache entry and the associated metadata for all Dataset (if any), that contain - * the given data source path. + * the given data source path. Path matching is by prefix, i.e. "/" would invalidate + * everything that is cached. * * @group cachemgmt * @since 2.0.0 From 5a693b4138d4ce948e3bcdbe28d5c01d5deb8fa9 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 6 Apr 2017 09:15:13 -0700 Subject: [PATCH 227/512] [SPARK-20195][SPARKR][SQL] add createTable catalog API and deprecate createExternalTable ## What changes were proposed in this pull request? Following up on #17483, add createTable (which is new in 2.2.0) and deprecate createExternalTable, plus a number of minor fixes ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #17511 from felixcheung/rceatetable. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 4 +- R/pkg/R/catalog.R | 59 +++++++++++++++++++---- R/pkg/inst/tests/testthat/test_sparkSQL.R | 20 ++++++-- 4 files changed, 68 insertions(+), 16 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9b7e95ce30ac..ca45c6f9b0a9 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -361,6 +361,7 @@ export("as.DataFrame", "clearCache", "createDataFrame", "createExternalTable", + "createTable", "currentDatabase", "dropTempTable", "dropTempView", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 97786df4ae6a..ec85f723c08c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -557,7 +557,7 @@ setMethod("insertInto", jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - callJMethod(write, "insertInto", tableName) + invisible(callJMethod(write, "insertInto", tableName)) }) #' Cache @@ -2894,7 +2894,7 @@ setMethod("saveAsTable", write <- callJMethod(write, "format", source) write <- callJMethod(write, "mode", jmode) write <- callJMethod(write, "options", options) - callJMethod(write, "saveAsTable", tableName) + invisible(callJMethod(write, "saveAsTable", tableName)) }) #' summary diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index 4b7f841b55dd..e59a7024333a 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -17,7 +17,7 @@ # catalog.R: SparkSession catalog functions -#' Create an external table +#' (Deprecated) Create an external table #' #' Creates an external table based on the dataset in a data source, #' Returns a SparkDataFrame associated with the external table. @@ -29,10 +29,11 @@ #' @param tableName a name of the table. #' @param path the path of files to load. #' @param source the name of external data source. -#' @param schema the schema of the data for certain data source. 
+#' @param schema the schema of the data required for some data sources. #' @param ... additional argument(s) passed to the method. #' @return A SparkDataFrame. -#' @rdname createExternalTable +#' @rdname createExternalTable-deprecated +#' @seealso \link{createTable} #' @export #' @examples #'\dontrun{ @@ -43,24 +44,64 @@ #' @method createExternalTable default #' @note createExternalTable since 1.4.0 createExternalTable.default <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { + .Deprecated("createTable", old = "createExternalTable") + createTable(tableName, path, source, schema, ...) +} + +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) +} + +#' Creates a table based on the dataset in a data source +#' +#' Creates a table based on the dataset in a data source. Returns a SparkDataFrame associated with +#' the table. +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. When a \code{path} is specified, an external table is +#' created from the data at the given path. Otherwise a managed table is created. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. +#' @param path (optional) the path of files to load. +#' @param source (optional) the name of the data source. +#' @param schema (optional) the schema of the data required for some data sources. +#' @param ... additional named parameters as options for the data source. +#' @return A SparkDataFrame. +#' @rdname createTable +#' @seealso \link{createExternalTable} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createTable("myjson", path="path/to/json", source="json", schema) +#' +#' createTable("people", source = "json", schema = schema) +#' insertInto(df, "people") +#' } +#' @name createTable +#' @note createTable since 2.2.0 +createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { sparkSession <- getSparkSession() options <- varargsToStrEnv(...) if (!is.null(path)) { options[["path"]] <- path } + if (is.null(source)) { + source <- getDefaultSqlSource() + } catalog <- callJMethod(sparkSession, "catalog") if (is.null(schema)) { - sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) + sdf <- callJMethod(catalog, "createTable", tableName, source, options) + } else if (class(schema) == "structType") { + sdf <- callJMethod(catalog, "createTable", tableName, source, schema$jobj, options) } else { - sdf <- callJMethod(catalog, "createExternalTable", tableName, source, schema$jobj, options) + stop("schema must be a structType.") } dataFrame(sdf) } -createExternalTable <- function(x, ...) { - dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) -} - #' Cache Table #' #' Caches the specified table in-memory. 
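The R `createTable()` shown above hands off to the JVM catalog via `callJMethod(catalog, "createTable", ...)`, so its behaviour matches the Scala `Catalog.createTable` documented in the previous patch: passing a `path` option registers an external table over existing data, while omitting it creates a managed table. A rough Scala sketch of both modes (the table names, schema, and path are illustrative only, not taken from the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object CreateTableSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("createTable sketch")
      .master("local[*]")
      .getOrCreate()

    val schema = new StructType()
      .add("name", StringType)
      .add("age", IntegerType)

    // External table: the "path" option points at existing JSON data,
    // so only metadata is registered in the catalog.
    spark.catalog.createTable("people_json", "json", schema,
      Map("path" -> "/tmp/people.json"))

    // Managed table: no path, so the data location is owned by Spark's warehouse.
    spark.catalog.createTable("people_managed", "json", schema,
      Map.empty[String, String])

    spark.stop()
  }
}
```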
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ad06711a79a7..58cf24256a94 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -281,7 +281,7 @@ test_that("create DataFrame from RDD", { setHiveContext(sc) sql("CREATE TABLE people (name string, age double, height float)") df <- read.df(jsonPathNa, "json", schema) - invisible(insertInto(df, "people")) + insertInto(df, "people") expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, @@ -1268,7 +1268,16 @@ test_that("column calculation", { test_that("test HiveContext", { setHiveContext(sc) - df <- createExternalTable("json", jsonPath, "json") + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) df2 <- sql("select * from json") @@ -1276,25 +1285,26 @@ test_that("test HiveContext", { expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + saveAsTable(df, "json2", "json", "append", path = jsonPath2) df3 <- sql("select * from json2") expect_is(df3, "SparkDataFrame") expect_equal(count(df3), 3) unlink(jsonPath2) hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) + saveAsTable(df, "hivetestbl", path = hivetestDataPath) df4 <- sql("select * from hivetestbl") expect_is(df4, "SparkDataFrame") expect_equal(count(df4), 3) unlink(hivetestDataPath) parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)) + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) df5 <- sql("select * from parquetest") expect_is(df5, "SparkDataFrame") expect_equal(count(df5), 3) unlink(parquetDataPath) + unsetHiveContext() }) From a4491626ed8169f0162a0dfb78736c9b9e7fb434 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 6 Apr 2017 13:23:54 -0500 Subject: [PATCH 228/512] [SPARK-17019][CORE] Expose on-heap and off-heap memory usage in various places ## What changes were proposed in this pull request? With [SPARK-13992](https://issues.apache.org/jira/browse/SPARK-13992), Spark supports persisting data into off-heap memory, but the usage of on-heap and off-heap memory is not exposed currently, it is not so convenient for user to monitor and profile, so here propose to expose off-heap memory as well as on-heap memory usage in various places: 1. Spark UI's executor page will display both on-heap and off-heap memory usage. 2. REST request returns both on-heap and off-heap memory. 3. Also this can be gotten from MetricsSystem. 4. Last this usage can be obtained programmatically from SparkListener. 
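As a minimal sketch of item 4, a user-supplied `SparkListener` can read the new optional on-heap/off-heap fields carried by `SparkListenerBlockManagerAdded` (see the event change in the diff below); the `getOrElse(0L)` fallback mirrors how replayed pre-2.2 event logs report the split as 0:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockManagerAdded}

// Logs the storage-memory split reported when each BlockManager registers.
class StorageMemoryListener extends SparkListener {
  override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = {
    val onHeap = event.maxOnHeapMem.getOrElse(0L)   // None when replaying old event logs
    val offHeap = event.maxOffHeapMem.getOrElse(0L)
    println(s"BlockManager on ${event.blockManagerId.hostPort}: " +
      s"max=${event.maxMem} bytes, onHeap=$onHeap bytes, offHeap=$offHeap bytes")
  }
}

// Attach it from application code, for example:
//   sc.addSparkListener(new StorageMemoryListener)
```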
Attach the UI changes: ![screen shot 2016-08-12 at 11 20 44 am](https://cloud.githubusercontent.com/assets/850797/17612032/6c2f4480-607f-11e6-82e8-a27fb8cbb4ae.png) Backward compatibility is also considered for event-log and REST API. Old event log can still be replayed with off-heap usage displayed as 0. For REST API, only adds the new fields, so JSON backward compatibility can still be kept. ## How was this patch tested? Unit test added and manual verification. Author: jerryshao Closes #14617 from jerryshao/SPARK-17019. --- .../ui/static/executorspage-template.html | 18 ++- .../apache/spark/ui/static/executorspage.js | 103 ++++++++++++- .../org/apache/spark/ui/static/webui.css | 3 +- .../spark/scheduler/SparkListener.scala | 9 +- .../spark/status/api/v1/AllRDDResource.scala | 8 +- .../org/apache/spark/status/api/v1/api.scala | 12 +- .../apache/spark/storage/BlockManager.scala | 9 +- .../spark/storage/BlockManagerMaster.scala | 5 +- .../storage/BlockManagerMasterEndpoint.scala | 22 ++- .../spark/storage/BlockManagerMessages.scala | 3 +- .../spark/storage/BlockManagerSource.scala | 66 +++++---- .../spark/storage/StorageStatusListener.scala | 8 +- .../apache/spark/storage/StorageUtils.scala | 99 +++++++++---- .../apache/spark/ui/exec/ExecutorsPage.scala | 46 +++++- .../org/apache/spark/ui/storage/RDDPage.scala | 11 +- .../org/apache/spark/util/JsonProtocol.scala | 8 +- .../executor_memory_usage_expectation.json | 139 ++++++++++++++++++ ...xecutor_node_blacklisting_expectation.json | 41 ++++-- .../spark-events/app-20161116163331-0000 | 10 +- .../deploy/history/HistoryServerSuite.scala | 3 +- .../apache/spark/storage/StorageSuite.scala | 87 ++++++++++- .../org/apache/spark/ui/UISeleniumSuite.scala | 36 ++++- project/MimaExcludes.scala | 11 +- 23 files changed, 638 insertions(+), 119 deletions(-) create mode 100644 core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html index 4e83d6d56498..5c91304e49fd 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html @@ -24,7 +24,15 @@

    Summary

    RDD Blocks Storage Memory + title="Memory used / total available memory for storage of data like RDD partitions cached in memory.">Storage Memory + + + On Heap Storage Memory + + + Off Heap Storage Memory Disk Used Cores @@ -73,6 +81,14 @@

    Executors

    Storage Memory + + + On Heap Storage Memory + + + Off Heap Storage Memory Disk Used Cores Active Tasks diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 7dbfe32de903..930a0698928d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -190,6 +190,10 @@ $(document).ready(function () { var allRDDBlocks = 0; var allMemoryUsed = 0; var allMaxMemory = 0; + var allOnHeapMemoryUsed = 0; + var allOnHeapMaxMemory = 0; + var allOffHeapMemoryUsed = 0; + var allOffHeapMaxMemory = 0; var allDiskUsed = 0; var allTotalCores = 0; var allMaxTasks = 0; @@ -208,6 +212,10 @@ $(document).ready(function () { var activeRDDBlocks = 0; var activeMemoryUsed = 0; var activeMaxMemory = 0; + var activeOnHeapMemoryUsed = 0; + var activeOnHeapMaxMemory = 0; + var activeOffHeapMemoryUsed = 0; + var activeOffHeapMaxMemory = 0; var activeDiskUsed = 0; var activeTotalCores = 0; var activeMaxTasks = 0; @@ -226,6 +234,10 @@ $(document).ready(function () { var deadRDDBlocks = 0; var deadMemoryUsed = 0; var deadMaxMemory = 0; + var deadOnHeapMemoryUsed = 0; + var deadOnHeapMaxMemory = 0; + var deadOffHeapMemoryUsed = 0; + var deadOffHeapMaxMemory = 0; var deadDiskUsed = 0; var deadTotalCores = 0; var deadMaxTasks = 0; @@ -240,11 +252,22 @@ $(document).ready(function () { var deadTotalShuffleWrite = 0; var deadTotalBlacklisted = 0; + response.forEach(function (exec) { + exec.onHeapMemoryUsed = exec.hasOwnProperty('onHeapMemoryUsed') ? exec.onHeapMemoryUsed : 0; + exec.maxOnHeapMemory = exec.hasOwnProperty('maxOnHeapMemory') ? exec.maxOnHeapMemory : 0; + exec.offHeapMemoryUsed = exec.hasOwnProperty('offHeapMemoryUsed') ? exec.offHeapMemoryUsed : 0; + exec.maxOffHeapMemory = exec.hasOwnProperty('maxOffHeapMemory') ? 
exec.maxOffHeapMemory : 0; + }); + response.forEach(function (exec) { allExecCnt += 1; allRDDBlocks += exec.rddBlocks; allMemoryUsed += exec.memoryUsed; allMaxMemory += exec.maxMemory; + allOnHeapMemoryUsed += exec.onHeapMemoryUsed; + allOnHeapMaxMemory += exec.maxOnHeapMemory; + allOffHeapMemoryUsed += exec.offHeapMemoryUsed; + allOffHeapMaxMemory += exec.maxOffHeapMemory; allDiskUsed += exec.diskUsed; allTotalCores += exec.totalCores; allMaxTasks += exec.maxTasks; @@ -263,6 +286,10 @@ $(document).ready(function () { activeRDDBlocks += exec.rddBlocks; activeMemoryUsed += exec.memoryUsed; activeMaxMemory += exec.maxMemory; + activeOnHeapMemoryUsed += exec.onHeapMemoryUsed; + activeOnHeapMaxMemory += exec.maxOnHeapMemory; + activeOffHeapMemoryUsed += exec.offHeapMemoryUsed; + activeOffHeapMaxMemory += exec.maxOffHeapMemory; activeDiskUsed += exec.diskUsed; activeTotalCores += exec.totalCores; activeMaxTasks += exec.maxTasks; @@ -281,6 +308,10 @@ $(document).ready(function () { deadRDDBlocks += exec.rddBlocks; deadMemoryUsed += exec.memoryUsed; deadMaxMemory += exec.maxMemory; + deadOnHeapMemoryUsed += exec.onHeapMemoryUsed; + deadOnHeapMaxMemory += exec.maxOnHeapMemory; + deadOffHeapMemoryUsed += exec.offHeapMemoryUsed; + deadOffHeapMaxMemory += exec.maxOffHeapMemory; deadDiskUsed += exec.diskUsed; deadTotalCores += exec.totalCores; deadMaxTasks += exec.maxTasks; @@ -302,6 +333,10 @@ $(document).ready(function () { "allRDDBlocks": allRDDBlocks, "allMemoryUsed": allMemoryUsed, "allMaxMemory": allMaxMemory, + "allOnHeapMemoryUsed": allOnHeapMemoryUsed, + "allOnHeapMaxMemory": allOnHeapMaxMemory, + "allOffHeapMemoryUsed": allOffHeapMemoryUsed, + "allOffHeapMaxMemory": allOffHeapMaxMemory, "allDiskUsed": allDiskUsed, "allTotalCores": allTotalCores, "allMaxTasks": allMaxTasks, @@ -321,6 +356,10 @@ $(document).ready(function () { "allRDDBlocks": activeRDDBlocks, "allMemoryUsed": activeMemoryUsed, "allMaxMemory": activeMaxMemory, + "allOnHeapMemoryUsed": activeOnHeapMemoryUsed, + "allOnHeapMaxMemory": activeOnHeapMaxMemory, + "allOffHeapMemoryUsed": activeOffHeapMemoryUsed, + "allOffHeapMaxMemory": activeOffHeapMaxMemory, "allDiskUsed": activeDiskUsed, "allTotalCores": activeTotalCores, "allMaxTasks": activeMaxTasks, @@ -340,6 +379,10 @@ $(document).ready(function () { "allRDDBlocks": deadRDDBlocks, "allMemoryUsed": deadMemoryUsed, "allMaxMemory": deadMaxMemory, + "allOnHeapMemoryUsed": deadOnHeapMemoryUsed, + "allOnHeapMaxMemory": deadOnHeapMaxMemory, + "allOffHeapMemoryUsed": deadOffHeapMemoryUsed, + "allOffHeapMaxMemory": deadOffHeapMaxMemory, "allDiskUsed": deadDiskUsed, "allTotalCores": deadTotalCores, "allMaxTasks": deadMaxTasks, @@ -378,7 +421,35 @@ $(document).ready(function () { {data: 'rddBlocks'}, { data: function (row, type) { - return type === 'display' ? 
(formatBytes(row.memoryUsed, type) + ' / ' + formatBytes(row.maxMemory, type)) : row.memoryUsed; + if (type !== 'display') + return row.memoryUsed; + else + return (formatBytes(row.memoryUsed, type) + ' / ' + + formatBytes(row.maxMemory, type)); + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.onHeapMemoryUsed; + else + return (formatBytes(row.onHeapMemoryUsed, type) + ' / ' + + formatBytes(row.maxOnHeapMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('on_heap_memory') + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.offHeapMemoryUsed; + else + return (formatBytes(row.offHeapMemoryUsed, type) + ' / ' + + formatBytes(row.maxOffHeapMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('off_heap_memory') } }, {data: 'diskUsed', render: formatBytes}, @@ -450,7 +521,35 @@ $(document).ready(function () { {data: 'allRDDBlocks'}, { data: function (row, type) { - return type === 'display' ? (formatBytes(row.allMemoryUsed, type) + ' / ' + formatBytes(row.allMaxMemory, type)) : row.allMemoryUsed; + if (type !== 'display') + return row.allMemoryUsed + else + return (formatBytes(row.allMemoryUsed, type) + ' / ' + + formatBytes(row.allMaxMemory, type)); + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.allOnHeapMemoryUsed; + else + return (formatBytes(row.allOnHeapMemoryUsed, type) + ' / ' + + formatBytes(row.allOnHeapMaxMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('on_heap_memory') + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.allOffHeapMemoryUsed; + else + return (formatBytes(row.allOffHeapMemoryUsed, type) + ' / ' + + formatBytes(row.allOffHeapMaxMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('off_heap_memory') } }, {data: 'allDiskUsed', render: formatBytes}, diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 319a719efaa7..935d9b1aec61 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -205,7 +205,8 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. 
*/ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time, .peak_execution_memory { +.serialization_time, .getting_result_time, .peak_execution_memory, +.on_heap_memory, .off_heap_memory { display: none; } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 4331addb4417..bc2e53071668 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -87,8 +87,13 @@ case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(S extends SparkListenerEvent @DeveloperApi -case class SparkListenerBlockManagerAdded(time: Long, blockManagerId: BlockManagerId, maxMem: Long) - extends SparkListenerEvent +case class SparkListenerBlockManagerAdded( + time: Long, + blockManagerId: BlockManagerId, + maxMem: Long, + maxOnHeapMem: Option[Long] = None, + maxOffHeapMem: Option[Long] = None) extends SparkListenerEvent { +} @DeveloperApi case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockManagerId) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala index 5c03609e5e5e..1279b281ad8d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -70,7 +70,13 @@ private[spark] object AllRDDResource { address = status.blockManagerId.hostPort, memoryUsed = status.memUsedByRdd(rddId), memoryRemaining = status.memRemaining, - diskUsed = status.diskUsedByRdd(rddId) + diskUsed = status.diskUsedByRdd(rddId), + onHeapMemoryUsed = Some( + if (!rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), + offHeapMemoryUsed = Some( + if (rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), + onHeapMemoryRemaining = status.onHeapMemRemaining, + offHeapMemoryRemaining = status.offHeapMemRemaining ) } ) } else { None diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 5b9227350eda..d159b9450ef5 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -75,7 +75,11 @@ class ExecutorSummary private[spark]( val totalShuffleWrite: Long, val isBlacklisted: Boolean, val maxMemory: Long, - val executorLogs: Map[String, String]) + val executorLogs: Map[String, String], + val onHeapMemoryUsed: Option[Long], + val offHeapMemoryUsed: Option[Long], + val maxOnHeapMemory: Option[Long], + val maxOffHeapMemory: Option[Long]) class JobData private[spark]( val jobId: Int, @@ -111,7 +115,11 @@ class RDDDataDistribution private[spark]( val address: String, val memoryUsed: Long, val memoryRemaining: Long, - val diskUsed: Long) + val diskUsed: Long, + val onHeapMemoryUsed: Option[Long], + val offHeapMemoryUsed: Option[Long], + val onHeapMemoryRemaining: Option[Long], + val offHeapMemoryRemaining: Option[Long]) class RDDPartitionInfo private[spark]( val blockName: String, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 46a078b2f9f9..63acba65d3c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -150,8 +150,8 @@ private[spark] class BlockManager( // However, since we use this only for reporting and logging, what we actually want here is // the absolute maximum value that `maxMemory` can ever possibly reach. We may need // to revisit whether reporting this value as the "max" is intuitive to the user. - private val maxMemory = - memoryManager.maxOnHeapStorageMemory + memoryManager.maxOffHeapStorageMemory + private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory + private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. @@ -229,7 +229,8 @@ private[spark] class BlockManager( val idFromMaster = master.registerBlockManager( id, - maxMemory, + maxOnHeapMemory, + maxOffHeapMemory, slaveEndpoint) blockManagerId = if (idFromMaster != null) idFromMaster else id @@ -307,7 +308,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo(s"BlockManager $blockManagerId re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) + master.registerBlockManager(blockManagerId, maxOnHeapMemory, maxOffHeapMemory, slaveEndpoint) reportAllBlocks() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 3ca690db9e79..ea5d8423a588 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -57,11 +57,12 @@ class BlockManagerMaster( */ def registerBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { logInfo(s"Registering BlockManager $blockManagerId") val updatedId = driverEndpoint.askSync[BlockManagerId]( - RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) + RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) logInfo(s"Registered BlockManager $updatedId") updatedId } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 84c04d22600a..467c3e0e6b51 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -71,8 +71,8 @@ class BlockManagerMasterEndpoint( logInfo("BlockManagerMasterEndpoint up") override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => - context.reply(register(blockManagerId, maxMemSize, slaveEndpoint)) + case RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) => + context.reply(register(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => @@ -276,7 +276,8 @@ class BlockManagerMasterEndpoint( private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) + 
new StorageStatus(blockManagerId, info.maxMem, Some(info.maxOnHeapMem), + Some(info.maxOffHeapMem), info.blocks.asScala) }.toArray } @@ -338,7 +339,8 @@ class BlockManagerMasterEndpoint( */ private def register( idWithoutTopologyInfo: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { // the dummy id is not expected to contain the topology information. // we get that info here and respond back with a more fleshed out block manager id @@ -359,14 +361,15 @@ class BlockManagerMasterEndpoint( case None => } logInfo("Registering block manager %s with %s RAM, %s".format( - id.hostPort, Utils.bytesToString(maxMemSize), id)) + id.hostPort, Utils.bytesToString(maxOnHeapMemSize + maxOffHeapMemSize), id)) blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) + id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) } - listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, + Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) id } @@ -464,10 +467,13 @@ object BlockStatus { private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, - val maxMem: Long, + val maxOnHeapMem: Long, + val maxOffHeapMem: Long, val slaveEndpoint: RpcEndpointRef) extends Logging { + val maxMem = maxOnHeapMem + maxOffHeapMem + private var _lastSeenMs: Long = timeMs private var _remainingMem: Long = maxMem diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 0aea438e7f47..0c0ff144596a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -58,7 +58,8 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, sender: RpcEndpointRef) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index c5ba9af3e265..197a01762c0c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -26,35 +26,39 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager) override val metricRegistry = new MetricRegistry() override val sourceName = "BlockManager" - metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).sum - maxMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val remainingMem = storageStatusList.map(_.memRemaining).sum - remainingMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val memUsed = 
storageStatusList.map(_.memUsed).sum - memUsed / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val diskSpaceUsed = storageStatusList.map(_.diskUsed).sum - diskSpaceUsed / 1024 / 1024 - } - }) + private def registerGauge(name: String, func: BlockManagerMaster => Long): Unit = { + metricRegistry.register(name, new Gauge[Long] { + override def getValue: Long = func(blockManager.master) / 1024 / 1024 + }) + } + + registerGauge(MetricRegistry.name("memory", "maxMem_MB"), + _.getStorageStatus.map(_.maxMem).sum) + + registerGauge(MetricRegistry.name("memory", "maxOnHeapMem_MB"), + _.getStorageStatus.map(_.maxOnHeapMem.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "maxOffHeapMem_MB"), + _.getStorageStatus.map(_.maxOffHeapMem.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "remainingMem_MB"), + _.getStorageStatus.map(_.memRemaining).sum) + + registerGauge(MetricRegistry.name("memory", "remainingOnHeapMem_MB"), + _.getStorageStatus.map(_.onHeapMemRemaining.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "remainingOffHeapMem_MB"), + _.getStorageStatus.map(_.offHeapMemRemaining.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "memUsed_MB"), + _.getStorageStatus.map(_.memUsed).sum) + + registerGauge(MetricRegistry.name("memory", "onHeapMemUsed_MB"), + _.getStorageStatus.map(_.onHeapMemUsed.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "offHeapMemUsed_MB"), + _.getStorageStatus.map(_.offHeapMemUsed.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("disk", "diskSpaceUsed_MB"), + _.getStorageStatus.map(_.diskUsed).sum) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 798658a15b79..1b30d4fa93bc 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -41,7 +41,7 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener { } def deadStorageStatusList: Seq[StorageStatus] = synchronized { - deadExecutorStorageStatus.toSeq + deadExecutorStorageStatus } /** Update storage status list to reflect updated block statuses */ @@ -74,8 +74,10 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener { synchronized { val blockManagerId = blockManagerAdded.blockManagerId val executorId = blockManagerId.executorId - val maxMem = blockManagerAdded.maxMem - val storageStatus = new StorageStatus(blockManagerId, maxMem) + // The onHeap and offHeap memory are always defined for new applications, + // but they can be missing if we are replaying old event logs. + val storageStatus = new StorageStatus(blockManagerId, blockManagerAdded.maxMem, + blockManagerAdded.maxOnHeapMem, blockManagerAdded.maxOffHeapMem) executorIdToStorageStatus(executorId) = storageStatus // Try to remove the dead storage status if same executor register the block manager twice. 
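The `registerGauge` helper added to `BlockManagerSource` above replaces one hand-written anonymous `Gauge` per metric with a single factory, which keeps the new on-heap/off-heap gauges cheap to add. The same pattern works for any Dropwizard-metrics source; a self-contained sketch (the `StorageSnapshot` type and its fields are invented for illustration and are not part of Spark):

```scala
import com.codahale.metrics.{Gauge, MetricRegistry}

// Invented stand-in for whatever object supplies the raw byte counts.
case class StorageSnapshot(
    maxOnHeapMem: Long,
    maxOffHeapMem: Long,
    onHeapMemUsed: Long,
    offHeapMemUsed: Long)

class StorageMetricsSource(snapshot: () => StorageSnapshot) {
  val metricRegistry = new MetricRegistry()

  // One factory instead of repeating "new Gauge[Long] { ... }" for every metric,
  // mirroring the refactored BlockManagerSource; values are reported in MB.
  private def registerGauge(name: String, f: StorageSnapshot => Long): Unit = {
    metricRegistry.register(name, new Gauge[Long] {
      override def getValue: Long = f(snapshot()) / 1024 / 1024
    })
  }

  registerGauge(MetricRegistry.name("memory", "maxOnHeapMem_MB"), _.maxOnHeapMem)
  registerGauge(MetricRegistry.name("memory", "maxOffHeapMem_MB"), _.maxOffHeapMem)
  registerGauge(MetricRegistry.name("memory", "onHeapMemUsed_MB"), _.onHeapMemUsed)
  registerGauge(MetricRegistry.name("memory", "offHeapMemUsed_MB"), _.offHeapMemUsed)
}
```

Once wired into Spark's MetricsSystem these gauges would typically surface under names like `<appId>.driver.BlockManager.memory.onHeapMemUsed_MB`, depending on the configured metrics namespace and sink.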
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 241aacd74b58..8f0d181fc8fe 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -35,7 +35,11 @@ import org.apache.spark.internal.Logging * class cannot mutate the source of the information. Accesses are not thread-safe. */ @DeveloperApi -class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { +class StorageStatus( + val blockManagerId: BlockManagerId, + val maxMemory: Long, + val maxOnHeapMem: Option[Long], + val maxOffHeapMem: Option[Long]) { /** * Internal representation of the blocks stored in this block manager. @@ -46,25 +50,21 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] - /** - * Storage information of the blocks that entails memory, disk, and off-heap memory usage. - * - * As with the block maps, we store the storage information separately for RDD blocks and - * non-RDD blocks for the same reason. In particular, RDD storage information is stored - * in a map indexed by the RDD ID to the following 4-tuple: - * - * (memory size, disk size, storage level) - * - * We assume that all the blocks that belong to the same RDD have the same storage level. - * This field is not relevant to non-RDD blocks, however, so the storage information for - * non-RDD blocks contains only the first 3 fields (in the same order). - */ - private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, StorageLevel)] - private var _nonRddStorageInfo: (Long, Long) = (0L, 0L) + private case class RddStorageInfo(memoryUsage: Long, diskUsage: Long, level: StorageLevel) + private val _rddStorageInfo = new mutable.HashMap[Int, RddStorageInfo] + + private case class NonRddStorageInfo(var onHeapUsage: Long, var offHeapUsage: Long, + var diskUsage: Long) + private val _nonRddStorageInfo = NonRddStorageInfo(0L, 0L, 0L) /** Create a storage status with an initial set of blocks, leaving the source unmodified. */ - def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) { - this(bmid, maxMem) + def this( + bmid: BlockManagerId, + maxMemory: Long, + maxOnHeapMem: Option[Long], + maxOffHeapMem: Option[Long], + initialBlocks: Map[BlockId, BlockStatus]) { + this(bmid, maxMemory, maxOnHeapMem, maxOffHeapMem) initialBlocks.foreach { case (bid, bstatus) => addBlock(bid, bstatus) } } @@ -176,26 +176,57 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { */ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) + /** Return the max memory can be used by this block manager. */ + def maxMem: Long = maxMemory + /** Return the memory remaining in this block manager. */ def memRemaining: Long = maxMem - memUsed + /** Return the memory used by caching RDDs */ + def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) + /** Return the memory used by this block manager. 
*/ - def memUsed: Long = _nonRddStorageInfo._1 + cacheSize + def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) - /** Return the memory used by caching RDDs */ - def cacheSize: Long = _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + /** Return the on-heap memory remaining in this block manager. */ + def onHeapMemRemaining: Option[Long] = + for (m <- maxOnHeapMem; o <- onHeapMemUsed) yield m - o + + /** Return the off-heap memory remaining in this block manager. */ + def offHeapMemRemaining: Option[Long] = + for (m <- maxOffHeapMem; o <- offHeapMemUsed) yield m - o + + /** Return the on-heap memory used by this block manager. */ + def onHeapMemUsed: Option[Long] = onHeapCacheSize.map(_ + _nonRddStorageInfo.onHeapUsage) + + /** Return the off-heap memory used by this block manager. */ + def offHeapMemUsed: Option[Long] = offHeapCacheSize.map(_ + _nonRddStorageInfo.offHeapUsage) + + /** Return the memory used by on-heap caching RDDs */ + def onHeapCacheSize: Option[Long] = maxOnHeapMem.map { _ => + _rddStorageInfo.collect { + case (_, storageInfo) if !storageInfo.level.useOffHeap => storageInfo.memoryUsage + }.sum + } + + /** Return the memory used by off-heap caching RDDs */ + def offHeapCacheSize: Option[Long] = maxOffHeapMem.map { _ => + _rddStorageInfo.collect { + case (_, storageInfo) if storageInfo.level.useOffHeap => storageInfo.memoryUsage + }.sum + } /** Return the disk space used by this block manager. */ - def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) + def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) /** Return the disk space used by the given RDD in this block manager in O(1) time. */ - def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L) + def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._3) + def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) /** * Update the relevant storage info, taking into account any existing status for this block. 
@@ -210,10 +241,12 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { val (oldMem, oldDisk) = blockId match { case RDDBlockId(rddId, _) => _rddStorageInfo.get(rddId) - .map { case (mem, disk, _) => (mem, disk) } + .map { case RddStorageInfo(mem, disk, _) => (mem, disk) } .getOrElse((0L, 0L)) - case _ => - _nonRddStorageInfo + case _ if !level.useOffHeap => + (_nonRddStorageInfo.onHeapUsage, _nonRddStorageInfo.diskUsage) + case _ if level.useOffHeap => + (_nonRddStorageInfo.offHeapUsage, _nonRddStorageInfo.diskUsage) } val newMem = math.max(oldMem + changeInMem, 0L) val newDisk = math.max(oldDisk + changeInDisk, 0L) @@ -225,13 +258,17 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { if (newMem + newDisk == 0) { _rddStorageInfo.remove(rddId) } else { - _rddStorageInfo(rddId) = (newMem, newDisk, level) + _rddStorageInfo(rddId) = RddStorageInfo(newMem, newDisk, level) } case _ => - _nonRddStorageInfo = (newMem, newDisk) + if (!level.useOffHeap) { + _nonRddStorageInfo.onHeapUsage = newMem + } else { + _nonRddStorageInfo.offHeapUsage = newMem + } + _nonRddStorageInfo.diskUsage = newDisk } } - } /** Helper methods for storage-related objects. */ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index d849ce76a9e3..0a3c63d14ca8 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -40,7 +40,8 @@ private[ui] case class ExecutorSummaryInfo( totalShuffleRead: Long, totalShuffleWrite: Long, isBlacklisted: Int, - maxMemory: Long, + maxOnHeapMem: Long, + maxOffHeapMem: Long, executorLogs: Map[String, String]) @@ -53,6 +54,34 @@ private[ui] class ExecutorsPage( val content =
    { + [markup for the new "Show Additional Metrics" checkbox and the on-heap/off-heap memory tooltips]
    ++ ++ ++ @@ -65,6 +94,11 @@ private[ui] class ExecutorsPage( } private[spark] object ExecutorsPage { + private val ON_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for on heap " + + "storage of data like RDD partitions cached in memory." + private val OFF_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for off heap " + + "storage of data like RDD partitions cached in memory." + /** Represent an executor's info as a map given a storage status index */ def getExecInfo( listener: ExecutorsListener, @@ -80,6 +114,10 @@ private[spark] object ExecutorsPage { val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem + val onHeapMemUsed = status.onHeapMemUsed + val offHeapMemUsed = status.offHeapMemUsed + val maxOnHeapMem = status.maxOnHeapMem + val maxOffHeapMem = status.maxOffHeapMem val diskUsed = status.diskUsed val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) @@ -103,7 +141,11 @@ private[spark] object ExecutorsPage { taskSummary.shuffleWrite, taskSummary.isBlacklisted, maxMem, - taskSummary.executorLogs + taskSummary.executorLogs, + onHeapMemUsed, + offHeapMemUsed, + maxOnHeapMem, + maxOffHeapMem ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 227e940c9c50..a1a0c729b924 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -147,7 +147,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { /** Header fields for the worker table */ private def workerHeader = Seq( "Host", - "Memory Usage", + "On Heap Memory Usage", + "Off Heap Memory Usage", "Disk Usage") /** Render an HTML row representing a worker */ @@ -155,8 +156,12 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { {worker.address} - {Utils.bytesToString(worker.memoryUsed)} - ({Utils.bytesToString(worker.memoryRemaining)} Remaining) + {Utils.bytesToString(worker.onHeapMemoryUsed.getOrElse(0L))} + ({Utils.bytesToString(worker.onHeapMemoryRemaining.getOrElse(0L))} Remaining) + + + {Utils.bytesToString(worker.offHeapMemoryUsed.getOrElse(0L))} + ({Utils.bytesToString(worker.offHeapMemoryRemaining.getOrElse(0L))} Remaining) {Utils.bytesToString(worker.diskUsed)} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 1d2cb7acefa3..8296c4294242 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -182,7 +182,9 @@ private[spark] object JsonProtocol { ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerAdded) ~ ("Block Manager ID" -> blockManagerId) ~ ("Maximum Memory" -> blockManagerAdded.maxMem) ~ - ("Timestamp" -> blockManagerAdded.time) + ("Timestamp" -> blockManagerAdded.time) ~ + ("Maximum Onheap Memory" -> blockManagerAdded.maxOnHeapMem) ~ + ("Maximum Offheap Memory" -> blockManagerAdded.maxOffHeapMem) } def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = { @@ -612,7 +614,9 @@ private[spark] object JsonProtocol { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - SparkListenerBlockManagerAdded(time, 
blockManagerId, maxMem) + val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) + val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json new file mode 100644 index 000000000000..e732af266350 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -0,0 +1,139 @@ +[ { + "id" : "2", + "hostPort" : "172.22.0.167:51487", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +}, { + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +}, { + "id" : "1", + "hostPort" : "172.22.0.167:51490", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 4, + "totalTasks" : 4, + "totalDuration" : 3152, + "totalGCTime" : 68, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", + "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" + }, + + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +}, { + "id" : "0", + "hostPort" : "172.22.0.167:51491", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2551, + "totalGCTime" : 116, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", + "stderr" : 
"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 5914a1c2c4b6..e732af266350 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -17,11 +17,15 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" - } + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -41,8 +45,12 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, - "executorLogs" : { } + "maxMemory" : 908381388, + "executorLogs" : { }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -62,11 +70,16 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" - } + }, + + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -86,11 +99,15 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" - } + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -110,9 +127,13 @@ 
"totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" - } + }, + "onHeapMemoryUsed" : 0, + "offHeapMemoryUsed" : 0, + "maxOnHeapMemory" : 384093388, + "maxOffHeapMemory" : 524288000 } ] diff --git a/core/src/test/resources/spark-events/app-20161116163331-0000 b/core/src/test/resources/spark-events/app-20161116163331-0000 index 7566c9fc0a20..57cfc5b97312 100755 --- a/core/src/test/resources/spark-events/app-20161116163331-0000 +++ b/core/src/test/resources/spark-events/app-20161116163331-0000 @@ -1,15 +1,15 @@ {"Event":"SparkListenerLogStart","Spark Version":"2.1.0-SNAPSHOT"} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"172.22.0.167","Port":51475},"Maximum Memory":384093388,"Timestamp":1479335611477} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"172.22.0.167","Port":51475},"Maximum Memory":908381388,"Timestamp":1479335611477,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","Java Version":"1.8.0_92 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.task.maxTaskAttemptsPerExecutor":"3","spark.blacklist.enabled":"TRUE","spark.driver.host":"172.22.0.167","spark.blacklist.task.maxTaskAttemptsPerNode":"3","spark.eventLog.enabled":"TRUE","spark.driver.port":"51459","spark.repl.class.uri":"spark://172.22.0.167:51459/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/spark-1cbc97d0-7fe6-4c9f-8c2c-f6fe51ee3cf2/repl-39929169-ac4c-4c6d-b116-f648e4dd62ed","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"3","spark.scheduler.mode":"FIFO","spark.eventLog.overwrite":"TRUE","spark.blacklist.stage.maxFailedTasksPerExecutor":"3","spark.executor.id":"driver","spark.blacklist.application.maxFailedExecutorsPerNode":"2","spark.submit.deployMode":"client","spark.master":"local-cluster[4,4,1024]","spark.home":"/Users/Jose/IdeaProjects/spark","spark.eventLog.dir":"/Users/jose/logs","spark.sql.catalogImplementation":"in-memory","spark.eventLog.compress":"FALSE","spark.blacklist.application.maxFailedTasksPerExecutor":"1","spark.blacklist.timeout":"1000000","spark.app.id":"app-20161116163331-0000","spark.task.maxFailures":"4"},"System Properties":{"java.io.tmpdir":"/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle 
Corporation","java.vm.specification.version":"1.8","user.home":"/Users/Jose","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib","user.dir":"/Users/Jose/IdeaProjects/spark","java.library.path":"/Users/Jose/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.92-b14","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_92-b14","java.vm.info":"mixed mode","java.ext.dirs":"/Users/Jose/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","io.netty.maxDirectMemory":"0","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.11.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"jose","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[4,4,1024] --conf spark.blacklist.enabled=TRUE --conf spark.blacklist.timeout=1000000 --conf spark.blacklist.application.maxFailedTasksPerExecutor=1 --conf spark.eventLog.overwrite=TRUE --conf spark.blacklist.task.maxTaskAttemptsPerNode=3 --conf spark.blacklist.stage.maxFailedTasksPerExecutor=3 --conf spark.blacklist.task.maxTaskAttemptsPerExecutor=3 --conf spark.eventLog.compress=FALSE --conf spark.blacklist.stage.maxFailedExecutorsPerNode=3 --conf spark.eventLog.enabled=TRUE --conf spark.eventLog.dir=/Users/jose/logs --conf spark.blacklist.application.maxFailedExecutorsPerNode=2 --conf spark.task.maxFailures=4 --class org.apache.spark.repl.Main --name Spark shell spark-shell -i 
/Users/Jose/dev/jose-utils/blacklist/test-blacklist.scala","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","java.version":"1.8.0_92","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlet-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-schema-1.2.15.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-assembly_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-framework-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-client-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-common/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.5.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/repl/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-io-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/catalyst/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-continuation-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive-thriftserver/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/streaming/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-net-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-proxy-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/lz4-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/conf/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/unused-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.1.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/common/tags/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/cglib-2.2.1-v20090111.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/py4j-0.10.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/jars/*":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-shuffle/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.0-incubating.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill-java-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalap-2.11.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-plus-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/unsafe/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/sketch/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang-2.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-3.8.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-webapp-9.2.16.v20160414.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-io-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-xml-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/mllib/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalatest_2.11-2.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-client-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-jndi-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/graphx/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/examples/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jets3t-0.7.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-recipes-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.5.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jline-2.12.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/launcher/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlets-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/paranamer-2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-security-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7-tests.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-all-4.0.41.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/janino-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-server-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-http-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-model-1.2.15.jar":"System Classpath"}} {"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20161116163331-0000","Timestamp":1479335609916,"User":"jose"} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615320,"Executor ID":"3","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout","stderr":"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"172.22.0.167","Port":51485},"Maximum Memory":384093388,"Timestamp":1479335615387} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"172.22.0.167","Port":51485},"Maximum Memory":908381388,"Timestamp":1479335615387,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} 
{"Event":"SparkListenerExecutorAdded","Timestamp":1479335615393,"Executor ID":"2","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout","stderr":"http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr"}}} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615443,"Executor ID":"1","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout","stderr":"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"172.22.0.167","Port":51487},"Maximum Memory":384093388,"Timestamp":1479335615448} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"172.22.0.167","Port":51487},"Maximum Memory":908381388,"Timestamp":1479335615448,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615462,"Executor ID":"0","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout","stderr":"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.22.0.167","Port":51490},"Maximum Memory":384093388,"Timestamp":1479335615496} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.22.0.167","Port":51491},"Maximum Memory":384093388,"Timestamp":1479335615515} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.22.0.167","Port":51490},"Maximum Memory":908381388,"Timestamp":1479335615496,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.22.0.167","Port":51491},"Maximum Memory":908381388,"Timestamp":1479335615515,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1479335616467,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0],"Properties":{}} {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},"Properties":{}} {"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1479335616657,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index dcf83cb530a9..764156c3edc4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -153,7 +153,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> 
"applications/app-20161116163331-0000/executors", - "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors" + "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors", + "executor memory usage" -> "applications/app-20161116163331-0000/executors" // Todo: enable this test when logging the even of onBlockUpdated. See: SPARK-13845 // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0" ) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index e5733aebf607..da198f946fd6 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -27,7 +27,7 @@ class StorageSuite extends SparkFunSuite { // For testing add, update, and remove (for non-RDD blocks) private def storageStatus1: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) assert(status.blocks.isEmpty) assert(status.rddBlocks.isEmpty) assert(status.memUsed === 0L) @@ -74,7 +74,7 @@ class StorageSuite extends SparkFunSuite { // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) assert(status.rddBlocks.isEmpty) status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L)) @@ -252,9 +252,9 @@ class StorageSuite extends SparkFunSuite { // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations private def stockStorageStatuses: Seq[StorageStatus] = { - val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) - val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L) - val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L) + val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) + val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L, Some(2000L), Some(0L)) + val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L, Some(3000L), Some(0L)) status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L)) status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L)) @@ -332,4 +332,81 @@ class StorageSuite extends SparkFunSuite { assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) } + private val offheap = StorageLevel.OFF_HEAP + // For testing add, update, remove, get, and contains etc. 
for both RDD and non-RDD onheap + // and offheap blocks + private def storageStatus3: StorageStatus = { + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 2000L, Some(1000L), Some(1000L)) + assert(status.rddBlocks.isEmpty) + status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("man"), BlockStatus(offheap, 10L, 0L)) + status.addBlock(RDDBlockId(0, 0), BlockStatus(offheap, 10L, 0L)) + status.addBlock(RDDBlockId(1, 1), BlockStatus(offheap, 100L, 0L)) + status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L)) + status + } + + test("storage memUsed, diskUsed with on-heap and off-heap blocks") { + val status = storageStatus3 + def actualMemUsed: Long = status.blocks.values.map(_.memSize).sum + def actualDiskUsed: Long = status.blocks.values.map(_.diskSize).sum + + def actualOnHeapMemUsed: Long = + status.blocks.values.filter(!_.storageLevel.useOffHeap).map(_.memSize).sum + def actualOffHeapMemUsed: Long = + status.blocks.values.filter(_.storageLevel.useOffHeap).map(_.memSize).sum + + assert(status.maxMem === status.maxOnHeapMem.get + status.maxOffHeapMem.get) + + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) + assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) + + assert(status.memRemaining === status.maxMem - actualMemUsed) + assert(status.onHeapMemRemaining.get === status.maxOnHeapMem.get - actualOnHeapMemUsed) + assert(status.offHeapMemRemaining.get === status.maxOffHeapMem.get - actualOffHeapMemUsed) + + status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L)) + status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L)) + status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) + assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) + + status.removeBlock(TestBlockId("fire")) + status.removeBlock(TestBlockId("man")) + status.removeBlock(RDDBlockId(2, 2)) + status.removeBlock(RDDBlockId(2, 3)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + } + + private def storageStatus4: StorageStatus = { + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 2000L, None, None) + status + } + test("old SparkListenerBlockManagerAdded event compatible") { + // This scenario will only be happened when replaying old event log. In this scenario there's + // no block add or remove event replayed, so only total amount of memory is valid. 
+ val status = storageStatus4 + assert(status.maxMem === status.maxMemory) + + assert(status.memUsed === 0L) + assert(status.diskUsed === 0L) + assert(status.onHeapMemUsed === None) + assert(status.offHeapMemUsed === None) + + assert(status.memRemaining === status.maxMem) + assert(status.onHeapMemRemaining === None) + assert(status.offHeapMemRemaining === None) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 422837303642..f4c561c73779 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.LocalSparkContext._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.status.api.v1.{JacksonMessageWriter, StageStatus} +import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { @@ -103,6 +103,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") .set("spark.ui.killEnabled", killEnabled.toString) + .set("spark.memory.offHeap.size", "64m") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -151,6 +152,39 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedRddJson = getJson(ui, "storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) + + val dataDistributions0 = + (updatedRddJson \ "dataDistribution").extract[Seq[RDDDataDistribution]] + dataDistributions0.length should be (1) + val dist0 = dataDistributions0.head + + dist0.onHeapMemoryUsed should not be (None) + dist0.memoryUsed should be (dist0.onHeapMemoryUsed.get) + dist0.onHeapMemoryRemaining should not be (None) + dist0.offHeapMemoryRemaining should not be (None) + dist0.memoryRemaining should be ( + dist0.onHeapMemoryRemaining.get + dist0.offHeapMemoryRemaining.get) + dist0.onHeapMemoryUsed should not be (Some(0L)) + dist0.offHeapMemoryUsed should be (Some(0L)) + + rdd.unpersist() + rdd.persist(StorageLevels.OFF_HEAP).count() + val updatedStorageJson1 = getJson(ui, "storage/rdd") + updatedStorageJson1.children.length should be (1) + val updatedRddJson1 = getJson(ui, "storage/rdd/0") + val dataDistributions1 = + (updatedRddJson1 \ "dataDistribution").extract[Seq[RDDDataDistribution]] + dataDistributions1.length should be (1) + val dist1 = dataDistributions1.head + + dist1.offHeapMemoryUsed should not be (None) + dist1.memoryUsed should be (dist1.offHeapMemoryUsed.get) + dist1.onHeapMemoryRemaining should not be (None) + dist1.offHeapMemoryRemaining should not be (None) + dist1.memoryRemaining should be ( + dist1.onHeapMemoryRemaining.get + dist1.offHeapMemoryRemaining.get) + dist1.onHeapMemoryUsed should be (Some(0L)) + dist1.offHeapMemoryUsed should not be (Some(0L)) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2e3f9f2d0f3a..feae76a087de 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -100,7 +100,16 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes") - ) + ) ++ Seq( + // [SPARK-17019] Expose on-heap and off-heap memory usage in various places + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.this"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.apply"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.StorageStatus.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.RDDDataDistribution.this") + ) // Exclude rules for 2.1.x lazy val v21excludes = v20excludes ++ { From 8129d59d0e389fa8074958f1b90f7539e3e79bb7 Mon Sep 17 00:00:00 2001 From: Dustin Koupal Date: Thu, 6 Apr 2017 16:56:36 -0700 Subject: [PATCH 229/512] [MINOR][DOCS] Fix typo in Hive Examples ## What changes were proposed in this pull request? Fix typo in hive examples from "DaraFrames" to "DataFrames" ## How was this patch tested? N/A Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dustin Koupal Closes #17554 from cooper6581/typo-daraframes. --- .../apache/spark/examples/sql/hive/JavaSparkHiveExample.java | 2 +- examples/src/main/python/sql/hive.py | 2 +- .../org/apache/spark/examples/sql/hive/SparkHiveExample.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java index 47638565b166..575a463e8725 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java @@ -89,7 +89,7 @@ public static void main(String[] args) { // The results of SQL queries are themselves DataFrames and support all normal functions. Dataset sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key"); - // The items in DaraFrames are of type Row, which lets you to access each column by ordinal. + // The items in DataFrames are of type Row, which lets you to access each column by ordinal. Dataset stringsDS = sqlDF.map( (MapFunction) row -> "Key: " + row.get(0) + ", Value: " + row.get(1), Encoders.STRING()); diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index 1f175d725800..1f83a6fb48b9 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -68,7 +68,7 @@ # The results of SQL queries are themselves DataFrames and support all normal functions. sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") - # The items in DaraFrames are of type Row, which allows you to access each column by ordinal. + # The items in DataFrames are of type Row, which allows you to access each column by ordinal. 
stringsDS = sqlDF.rdd.map(lambda row: "Key: %d, Value: %s" % (row.key, row.value)) for record in stringsDS.collect(): print(record) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index 3de26364b528..e5f75d53edc8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -76,7 +76,7 @@ object SparkHiveExample { // The results of SQL queries are themselves DataFrames and support all normal functions. val sqlDF = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") - // The items in DaraFrames are of type Row, which allows you to access each column by ordinal. + // The items in DataFrames are of type Row, which allows you to access each column by ordinal. val stringsDS = sqlDF.map { case Row(key: Int, value: String) => s"Key: $key, Value: $value" } From 626b4cafce7d2dca186144336939d4d993b6f878 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Apr 2017 19:24:03 -0700 Subject: [PATCH 230/512] [SPARK-19495][SQL] Make SQLConf slightly more extensible - addendum ## What changes were proposed in this pull request? This is a tiny addendum to SPARK-19495 to remove the private visibility for copy, which is the only package private method in the entire file. ## How was this patch tested? N/A - no semantic change. Author: Reynold Xin Closes #17555 from rxin/SPARK-19495-2. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e685c2bed50a..640c0f189c23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1153,7 +1153,7 @@ class SQLConf extends Serializable with Logging { } // For test only - private[spark] def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { val cloned = clone() entries.foreach { case (entry, value) => cloned.setConfString(entry.key, value.toString) From ad3cc1312db3b5667cea134940a09896a4609b74 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Apr 2017 15:58:50 +0800 Subject: [PATCH 231/512] [SPARK-20245][SQL][MINOR] pass output to LogicalRelation directly ## What changes were proposed in this pull request? Currently `LogicalRelation` has a `expectedOutputAttributes` parameter, which makes it hard to reason about what the actual output is. Like other leaf nodes, `LogicalRelation` should also take `output` as a parameter, to simplify the logic ## How was this patch tested? existing tests Author: Wenchen Fan Closes #17552 from cloud-fan/minor. 
--- .../sql/catalyst/catalog/interface.scala | 8 ++-- .../datasources/DataSourceStrategy.scala | 15 +++---- .../datasources/LogicalRelation.scala | 39 +++++++------------ .../PruneFileSourcePartitions.scala | 4 +- .../spark/sql/sources/PathOptionSuite.scala | 19 ++++----- .../spark/sql/hive/HiveMetastoreCatalog.scala | 13 +++++-- .../spark/sql/hive/CachedTableSuite.scala | 4 +- .../PruneFileSourcePartitionsSuite.scala | 2 +- 8 files changed, 49 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index dc2e40424fd5..360e55d92282 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -27,7 +27,7 @@ import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -403,14 +403,14 @@ object CatalogTypes { */ case class CatalogRelation( tableMeta: CatalogTable, - dataCols: Seq[Attribute], - partitionCols: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { + dataCols: Seq[AttributeReference], + partitionCols: Seq[AttributeReference]) extends LeafNode with MultiInstanceRelation { assert(tableMeta.identifier.database.isDefined) assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType)) assert(tableMeta.dataSchema.sameType(dataCols.toStructType)) // The partition column should always appear after data columns. - override def output: Seq[Attribute] = dataCols ++ partitionCols + override def output: Seq[AttributeReference] = dataCols ++ partitionCols def isPartitioned: Boolean = partitionCols.nonEmpty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e5c7c383d708..2d83d512e702 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -231,16 +231,17 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] options = table.storage.properties ++ pathOption, catalogTable = Some(table)) - LogicalRelation( - dataSource.resolveRelation(checkFilesExist = false), - catalogTable = Some(table)) + LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table) } }).asInstanceOf[LogicalRelation] - // It's possible that the table schema is empty and need to be inferred at runtime. We should - // not specify expected outputs for this case. - val expectedOutputs = if (r.output.isEmpty) None else Some(r.output) - plan.copy(expectedOutputAttributes = expectedOutputs) + if (r.output.isEmpty) { + // It's possible that the table schema is empty and need to be inferred at runtime. 
For this + // case, we don't need to change the output of the cached plan. + plan + } else { + plan.copy(output = r.output) + } } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 3b14b794fd08..421520396007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation @@ -26,31 +26,13 @@ import org.apache.spark.util.Utils /** * Used to link a [[BaseRelation]] in to a logical query plan. - * - * Note that sometimes we need to use `LogicalRelation` to replace an existing leaf node without - * changing the output attributes' IDs. The `expectedOutputAttributes` parameter is used for - * this purpose. See https://issues.apache.org/jira/browse/SPARK-10741 for more details. */ case class LogicalRelation( relation: BaseRelation, - expectedOutputAttributes: Option[Seq[Attribute]] = None, - catalogTable: Option[CatalogTable] = None) + output: Seq[AttributeReference], + catalogTable: Option[CatalogTable]) extends LeafNode with MultiInstanceRelation { - override val output: Seq[AttributeReference] = { - val attrs = relation.schema.toAttributes - expectedOutputAttributes.map { expectedAttrs => - assert(expectedAttrs.length == attrs.length) - attrs.zip(expectedAttrs).map { - // We should respect the attribute names provided by base relation and only use the - // exprId in `expectedOutputAttributes`. - // The reason is that, some relations(like parquet) will reconcile attribute names to - // workaround case insensitivity issue. - case (attr, expected) => attr.withExprId(expected.exprId) - } - }.getOrElse(attrs) - } - // Logical Relations are distinct if they have different output for the sake of transformations. override def equals(other: Any): Boolean = other match { case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output @@ -87,11 +69,8 @@ case class LogicalRelation( * unique expression ids. We respect the `expectedOutputAttributes` and create * new instances of attributes in it. 
*/ - override def newInstance(): this.type = { - LogicalRelation( - relation, - expectedOutputAttributes.map(_.map(_.newInstance())), - catalogTable).asInstanceOf[this.type] + override def newInstance(): LogicalRelation = { + this.copy(output = output.map(_.newInstance())) } override def refresh(): Unit = relation match { @@ -101,3 +80,11 @@ case class LogicalRelation( override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation" } + +object LogicalRelation { + def apply(relation: BaseRelation): LogicalRelation = + LogicalRelation(relation, relation.schema.toAttributes, None) + + def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = + LogicalRelation(relation, relation.schema.toAttributes, Some(table)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 8566a8061034..905b8683e10b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -59,9 +59,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(sparkSession) - val prunedLogicalRelation = logicalRelation.copy( - relation = prunedFsRelation, - expectedOutputAttributes = Some(logicalRelation.output)) + val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation) // Keep partition-pruning predicates so that they are visible in physical planning val filterExpression = filters.reduceLeft(And) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index 60adee4599b0..6dd4847ead73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -75,13 +75,13 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |OPTIONS (PATH '/tmp/path') """.stripMargin) - assert(getPathOption("src") == Some("file:/tmp/path")) + assert(getPathOption("src").map(makeQualifiedPath) == Some(makeQualifiedPath("/tmp/path"))) } // should exist even path option is not specified when creating table withTable("src") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") - assert(getPathOption("src") == Some(CatalogUtils.URIToString(defaultTablePath("src")))) + assert(getPathOption("src").map(makeQualifiedPath) == Some(defaultTablePath("src"))) } } @@ -95,9 +95,9 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |OPTIONS (PATH '$p') |AS SELECT 1 """.stripMargin) - assert(CatalogUtils.stringToURI( - spark.table("src").schema.head.metadata.getString("path")) == - makeQualifiedPath(p.getAbsolutePath)) + assert( + spark.table("src").schema.head.metadata.getString("path") == + p.getAbsolutePath) } } @@ -109,8 +109,9 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |AS SELECT 1 """.stripMargin) - assert(spark.table("src").schema.head.metadata.getString("path") == - 
CatalogUtils.URIToString(defaultTablePath("src"))) + assert( + makeQualifiedPath(spark.table("src").schema.head.metadata.getString("path")) == + defaultTablePath("src")) } } @@ -122,13 +123,13 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |OPTIONS (PATH '/tmp/path')""".stripMargin) sql("ALTER TABLE src SET LOCATION '/tmp/path2'") - assert(getPathOption("src") == Some("/tmp/path2")) + assert(getPathOption("src").map(makeQualifiedPath) == Some(makeQualifiedPath("/tmp/path2"))) } withTable("src", "src2") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") sql("ALTER TABLE src RENAME TO src2") - assert(getPathOption("src2") == Some(CatalogUtils.URIToString(defaultTablePath("src2")))) + assert(getPathOption("src2").map(makeQualifiedPath) == Some(defaultTablePath("src2"))) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 10f432570e94..6b98066cb76c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -175,7 +175,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log bucketSpec = None, fileFormat = fileFormat, options = options)(sparkSession = sparkSession) - val created = LogicalRelation(fsRelation, catalogTable = Some(updatedTable)) + val created = LogicalRelation(fsRelation, updatedTable) tableRelationCache.put(tableIdentifier, created) created } @@ -203,7 +203,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log bucketSpec = None, options = options, className = fileType).resolveRelation(), - catalogTable = Some(updatedTable)) + table = updatedTable) tableRelationCache.put(tableIdentifier, created) created @@ -212,7 +212,14 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log logicalRelation }) } - result.copy(expectedOutputAttributes = Some(relation.output)) + // The inferred schema may have different filed names as the table schema, we should respect + // it, but also respect the exprId in table relation output. 
+ assert(result.output.length == relation.output.length && + result.output.zip(relation.output).forall { case (a1, a2) => a1.dataType == a2.dataType }) + val newOutput = result.output.zip(relation.output).map { + case (a1, a2) => a1.withExprId(a2.exprId) + } + result.copy(output = newOutput) } private def inferIfNeeded( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 2b3f36064c1f..d3cbf898e243 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -329,7 +329,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val plan = LogicalRelation(relation, catalogTable = Some(tableMeta)) + val plan = LogicalRelation(relation, tableMeta) spark.sharedState.cacheManager.cacheQuery(Dataset.ofRows(spark, plan)) assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined) @@ -342,7 +342,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto bucketSpec = None, fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val samePlan = LogicalRelation(sameRelation, catalogTable = Some(tableMeta)) + val samePlan = LogicalRelation(sameRelation, tableMeta) assert(spark.sharedState.cacheManager.lookupCachedData(samePlan).isDefined) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index cd8f94b1cc4f..f818e2955546 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -58,7 +58,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val logicalRelation = LogicalRelation(relation, catalogTable = Some(tableMeta)) + val logicalRelation = LogicalRelation(relation, tableMeta) val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze val optimized = Optimize.execute(query) From 1a52a62377a87cec493c8c6711bfd44e779c7973 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 7 Apr 2017 11:00:10 +0200 Subject: [PATCH 232/512] [SPARK-20076][ML][PYSPARK] Add Python interface for ml.stats.Correlation ## What changes were proposed in this pull request? The Dataframes-based support for the correlation statistics is added in #17108. This patch adds the Python interface for it. ## How was this patch tested? Python unit test. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17494 from viirya/correlation-python-api. 
--- .../apache/spark/ml/stat/Correlation.scala | 8 +-- python/pyspark/ml/stat.py | 61 +++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala index d3c84b77d26a..e185bc8a6faa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -38,7 +38,7 @@ object Correlation { /** * :: Experimental :: - * Compute the correlation matrix for the input RDD of Vectors using the specified method. + * Compute the correlation matrix for the input Dataset of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * * @param dataset A dataset or a dataframe @@ -56,14 +56,14 @@ object Correlation { * Here is how to access the correlation coefficient: * {{{ * val data: Dataset[Vector] = ... - * val Row(coeff: Matrix) = Statistics.corr(data, "value").head + * val Row(coeff: Matrix) = Correlation.corr(data, "value").head * // coeff now contains the Pearson correlation matrix. * }}} * * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], - * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to - * avoid recomputing the common lineage. + * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"` + * to avoid recomputing the common lineage. */ @Since("2.2.0") def corr(dataset: Dataset[_], column: String, method: String): DataFrame = { diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index db043ff68fec..079b0833e1c6 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -71,6 +71,67 @@ def test(dataset, featuresCol, labelCol): return _java2py(sc, javaTestObj.test(*args)) +class Correlation(object): + """ + .. note:: Experimental + + Compute the correlation matrix for the input dataset of Vectors using the specified method. + Methods currently supported: `pearson` (default), `spearman`. + + .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column + and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'` + to avoid recomputing the common lineage. + + :param dataset: + A dataset or a dataframe. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A dataframe that contains the correlation matrix of the column of vectors. This + dataframe contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... 
[Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) + + .. versionadded:: 2.2.0 + + """ + @staticmethod + @since("2.2.0") + def corr(dataset, column, method="pearson"): + """ + Compute the correlation matrix with specified method using dataset. + """ + sc = SparkContext._active_spark_context + javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation + args = [_py2java(sc, arg) for arg in (dataset, column, method)] + return _java2py(sc, javaCorrObj.corr(*args)) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 9e0893b53d68f777c1f3fb0a67820424a9c253ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Fri, 7 Apr 2017 13:03:07 +0100 Subject: [PATCH 233/512] [SPARK-20218][DOC][APP-ID] applications//stages' in REST API,add description. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. '/applications/[app-id]/stages' in rest api.status should add description '?status=[active|complete|pending|failed] list only stages in the state.' Now the lack of this description, resulting in the use of this api do not know the use of the status through the brush stage list. 2.'/applications/[app-id]/stages/[stage-id]' in REST API,remove redundant description ‘?status=[active|complete|pending|failed] list only stages in the state.’. Because only one stage is determined based on stage-id. code: GET def stageList(QueryParam("status") statuses: JList[StageStatus]): Seq[StageData] = { val listener = ui.jobProgressListener val stageAndStatus = AllStagesResource.stagesAndStatus(ui) val adjStatuses = { if (statuses.isEmpty()) { Arrays.asList(StageStatus.values(): _*) } else { statuses } }; ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Closes #17534 from guoxiaolongzte/SPARK-20218. --- docs/monitoring.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index 4d0617d253b8..da954385dc45 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -299,12 +299,12 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/stages A list of all stages for a given application. +
    ?status=[active|complete|pending|failed] list only stages in the state. /applications/[app-id]/stages/[stage-id] A list of all attempts for the given stage. -
    ?status=[active|complete|pending|failed] list only stages in the state. From 870b9d9aa00c260b532c78088e4a0384f7f1fa8a Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Fri, 7 Apr 2017 10:57:12 -0700 Subject: [PATCH 234/512] [SPARK-20026][DOC][SPARKR] Add Tweedie example for SparkR in programming guide ## What changes were proposed in this pull request? Add Tweedie example for SparkR in programming guide. The doc was already updated in #17103. Author: actuaryzhang Closes #17553 from actuaryzhang/programGuide. --- examples/src/main/r/ml/glm.R | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R index ee13910382c5..23141b57df14 100644 --- a/examples/src/main/r/ml/glm.R +++ b/examples/src/main/r/ml/glm.R @@ -56,6 +56,15 @@ summary(binomialGLM) # Prediction binomialPredictions <- predict(binomialGLM, binomialTestDF) head(binomialPredictions) + +# Fit a generalized linear model of family "tweedie" with spark.glm +training3 <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") +tweedieDF <- transform(training3, label = training3$label * exp(randn(10))) +tweedieGLM <- spark.glm(tweedieDF, label ~ features, family = "tweedie", + var.power = 1.2, link.power = 0) + +# Model summary +summary(tweedieGLM) # $example off$ sparkR.session.stop() From 8feb799af0bb67618310947342e3e4d2a77aae13 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 7 Apr 2017 11:17:49 -0700 Subject: [PATCH 235/512] [SPARK-20197][SPARKR] CRAN check fail with package installation ## What changes were proposed in this pull request? Test failed because SPARK_HOME is not set before Spark is installed. Author: Felix Cheung Closes #17516 from felixcheung/rdircheckincran. --- R/pkg/tests/run-all.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index cefaadda6e21..29812f872c78 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -22,12 +22,13 @@ library(SparkR) options("warn" = 2) # Setup global test environment +# Install Spark first to set SPARK_HOME +install.spark() + sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) -install.spark() - test_package("SparkR") From 1ad73f0a21d8007d8466ef8756f751c0ab6a9d1f Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Fri, 7 Apr 2017 12:29:45 -0700 Subject: [PATCH 236/512] [SPARK-20258][DOC][SPARKR] Fix SparkR logistic regression example in programming guide (did not converge) ## What changes were proposed in this pull request? SparkR logistic regression example did not converge in programming guide (for IRWLS). All estimates are essentially zero: ``` training2 <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm") df_list2 <- randomSplit(training2, c(7,3), 2) binomialDF <- df_list2[[1]] binomialTestDF <- df_list2[[2]] binomialGLM <- spark.glm(binomialDF, label ~ features, family = "binomial") 17/04/07 11:42:03 WARN WeightedLeastSquares: Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver. 
> summary(binomialGLM) Coefficients: Estimate (Intercept) 9.0255e+00 features_0 0.0000e+00 features_1 0.0000e+00 features_2 0.0000e+00 features_3 0.0000e+00 features_4 0.0000e+00 features_5 0.0000e+00 features_6 0.0000e+00 features_7 0.0000e+00 ``` Author: actuaryzhang Closes #17571 from actuaryzhang/programGuide2. --- examples/src/main/r/ml/glm.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R index 23141b57df14..68787f9aa9dc 100644 --- a/examples/src/main/r/ml/glm.R +++ b/examples/src/main/r/ml/glm.R @@ -27,7 +27,7 @@ sparkR.session(appName = "SparkR-ML-glm-example") # $example on$ training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -df_list <- randomSplit(training, c(7,3), 2) +df_list <- randomSplit(training, c(7, 3), 2) gaussianDF <- df_list[[1]] gaussianTestDF <- df_list[[2]] gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") @@ -44,8 +44,9 @@ gaussianGLM2 <- glm(label ~ features, gaussianDF, family = "gaussian") summary(gaussianGLM2) # Fit a generalized linear model of family "binomial" with spark.glm -training2 <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm") -df_list2 <- randomSplit(training2, c(7,3), 2) +training2 <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") +training2 <- transform(training2, label = cast(training2$label > 1, "integer")) +df_list2 <- randomSplit(training2, c(7, 3), 2) binomialDF <- df_list2[[1]] binomialTestDF <- df_list2[[2]] binomialGLM <- spark.glm(binomialDF, label ~ features, family = "binomial") From 589f3edb82e970b6df9121861ed0c6b4a6d02cb6 Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Fri, 7 Apr 2017 14:00:23 -0700 Subject: [PATCH 237/512] [SPARK-20255] Move listLeafFiles() to InMemoryFileIndex ## What changes were proposed in this pull request Trying to get a grip on the `FileIndex` hierarchy, I was confused by the following inconsistency: On the one hand, `PartitioningAwareFileIndex` defines `leafFiles` and `leafDirToChildrenFiles` as abstract, but on the other it fully implements `listLeafFiles` which does all the listing of files. However, the latter is only used by `InMemoryFileIndex`. I'm hereby proposing to move this method (and all its dependencies) to the implementation class that actually uses it, and thus unclutter the `PartitioningAwareFileIndex` interface. ## How was this patch tested? `./build/sbt sql/test` Author: Adrian Ionescu Closes #17570 from adrian-ionescu/list-leaf-files. 
--- .../datasources/InMemoryFileIndex.scala | 226 ++++++++++++++++++ .../PartitioningAwareFileIndex.scala | 223 +---------------- .../datasources/FileIndexSuite.scala | 18 +- 3 files changed, 236 insertions(+), 231 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index ee4d0863d977..11605dd28056 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution.datasources +import java.io.FileNotFoundException + import scala.collection.mutable +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration /** @@ -84,4 +91,223 @@ class InMemoryFileIndex( } override def hashCode(): Int = rootPaths.toSet.hashCode() + + /** + * List leaf files of given paths. This method will submit a Spark job to do parallel + * listing whenever there is a path having more files than the parallel partition discovery + * discovery threshold. + * + * This is publicly visible for testing. + */ + def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { + val output = mutable.LinkedHashSet[FileStatus]() + val pathsToFetch = mutable.ArrayBuffer[Path]() + for (path <- paths) { + fileStatusCache.getLeafFiles(path) match { + case Some(files) => + HiveCatalogMetrics.incrementFileCacheHits(files.length) + output ++= files + case None => + pathsToFetch += path + } + } + val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass)) + val discovered = InMemoryFileIndex.bulkListLeafFiles( + pathsToFetch, hadoopConf, filter, sparkSession) + discovered.foreach { case (path, leafFiles) => + HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) + fileStatusCache.putLeafFiles(path, leafFiles.toArray) + output ++= leafFiles + } + output + } +} + +object InMemoryFileIndex extends Logging { + + /** A serializable variant of HDFS's BlockLocation. */ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) + + /** + * Lists a collection of paths recursively. Picks the listing strategy adaptively depending + * on the number of paths to list. + * + * This may only be called on the driver. + * + * @return for each input path, the set of discovered files for the path + */ + private def bulkListLeafFiles( + paths: Seq[Path], + hadoopConf: Configuration, + filter: PathFilter, + sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { + + // Short-circuits parallel listing when serial listing is likely to be faster. 
+ if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + return paths.map { path => + (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) + } + } + + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + HiveCatalogMetrics.incrementParallelListingJobCount(1) + + val sparkContext = sparkSession.sparkContext + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = paths.map(_.toString) + val parallelPartitionDiscoveryParallelism = + sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) + + val statusMap = sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { pathStrings => + val hadoopConf = serializableConfiguration.value + pathStrings.map(new Path(_)).toSeq.map { path => + (path, listLeafFiles(path, hadoopConf, filter, None)) + }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + + // turn SerializableFileStatus back to Status + statusMap.map { case (path, serializableStatuses) => + val statuses = serializableStatuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, + new Path(f.path)), + blockLocations) + } + (new Path(path), statuses) + } + } + + /** + * Lists a single filesystem path recursively. If a SparkSession object is specified, this + * function may launch Spark jobs to parallelize listing. + * + * If sessionOpt is None, this may be called on executors. + * + * @return all children of path that match the specified filter. + */ + private def listLeafFiles( + path: Path, + hadoopConf: Configuration, + filter: PathFilter, + sessionOpt: Option[SparkSession]): Seq[FileStatus] = { + logTrace(s"Listing $path") + val fs = path.getFileSystem(hadoopConf) + val name = path.getName.toLowerCase + + // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist + // Note that statuses only include FileStatus for the files and dirs directly under path, + // and does not include anything else recursively. + val statuses = try fs.listStatus(path) catch { + case _: FileNotFoundException => + logWarning(s"The directory $path was not found. 
Was it deleted very recently?") + Array.empty[FileStatus] + } + + val filteredStatuses = statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) + + val allLeafStatuses = { + val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory) + val nestedFiles: Seq[FileStatus] = sessionOpt match { + case Some(session) => + bulkListLeafFiles(dirs.map(_.getPath), hadoopConf, filter, session).flatMap(_._2) + case _ => + dirs.flatMap(dir => listLeafFiles(dir.getPath, hadoopConf, filter, sessionOpt)) + } + val allFiles = topLevelFiles ++ nestedFiles + if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles + } + + allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + case f: LocatedFileStatus => + f + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `listLeafFilesInParallel` when the number of + // paths exceeds threshold. + case f => + // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), + // which is very slow on some file system (RawLocalFileSystem, which is launch a + // subprocess and parse the stdout). + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + lfs + } + } + + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // We filter follow paths: + // 1. everything that starts with _ and ., except _common_metadata and _metadata + // because Parquet needs to find those metadata files from leaf files returned by this method. + // We should refactor this logic to not mix metadata files with data files. + // 2. everything that ends with `._COPYING_`, because this is a intermediate state of file. we + // should skip this file in case of double reading. 
+ val exclude = (pathName.startsWith("_") && !pathName.contains("=")) || + pathName.startsWith(".") || pathName.endsWith("._COPYING_") + val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata") + exclude && !include + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 71500a010581..ffd7f6c750f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -17,22 +17,17 @@ package org.apache.spark.sql.execution.datasources -import java.io.FileNotFoundException - import scala.collection.mutable import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.internal.Logging -import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.SerializableConfiguration /** * An abstract class that represents [[FileIndex]]s that are aware of partitioned tables. @@ -241,224 +236,8 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } - - /** - * List leaf files of given paths. This method will submit a Spark job to do parallel - * listing whenever there is a path having more files than the parallel partition discovery - * discovery threshold. - * - * This is publicly visible for testing. - */ - def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - val output = mutable.LinkedHashSet[FileStatus]() - val pathsToFetch = mutable.ArrayBuffer[Path]() - for (path <- paths) { - fileStatusCache.getLeafFiles(path) match { - case Some(files) => - HiveCatalogMetrics.incrementFileCacheHits(files.length) - output ++= files - case None => - pathsToFetch += path - } - } - val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass)) - val discovered = PartitioningAwareFileIndex.bulkListLeafFiles( - pathsToFetch, hadoopConf, filter, sparkSession) - discovered.foreach { case (path, leafFiles) => - HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) - fileStatusCache.putLeafFiles(path, leafFiles.toArray) - output ++= leafFiles - } - output - } } -object PartitioningAwareFileIndex extends Logging { +object PartitioningAwareFileIndex { val BASE_PATH_PARAM = "basePath" - - /** A serializable variant of HDFS's BlockLocation. */ - private case class SerializableBlockLocation( - names: Array[String], - hosts: Array[String], - offset: Long, - length: Long) - - /** A serializable variant of HDFS's FileStatus. */ - private case class SerializableFileStatus( - path: String, - length: Long, - isDir: Boolean, - blockReplication: Short, - blockSize: Long, - modificationTime: Long, - accessTime: Long, - blockLocations: Array[SerializableBlockLocation]) - - /** - * Lists a collection of paths recursively. Picks the listing strategy adaptively depending - * on the number of paths to list. 
- * - * This may only be called on the driver. - * - * @return for each input path, the set of discovered files for the path - */ - private def bulkListLeafFiles( - paths: Seq[Path], - hadoopConf: Configuration, - filter: PathFilter, - sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { - - // Short-circuits parallel listing when serial listing is likely to be faster. - if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { - return paths.map { path => - (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) - } - } - - logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") - HiveCatalogMetrics.incrementParallelListingJobCount(1) - - val sparkContext = sparkSession.sparkContext - val serializableConfiguration = new SerializableConfiguration(hadoopConf) - val serializedPaths = paths.map(_.toString) - val parallelPartitionDiscoveryParallelism = - sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism - - // Set the number of parallelism to prevent following file listing from generating many tasks - // in case of large #defaultParallelism. - val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) - - val statusMap = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { pathStrings => - val hadoopConf = serializableConfiguration.value - pathStrings.map(new Path(_)).toSeq.map { path => - (path, listLeafFiles(path, hadoopConf, filter, None)) - }.iterator - }.map { case (path, statuses) => - val serializableStatuses = statuses.map { status => - // Turn FileStatus into SerializableFileStatus so we can send it back to the driver - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - SerializableBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[SerializableBlockLocation] - } - - SerializableFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) - } - (path.toString, serializableStatuses) - }.collect() - - // turn SerializableFileStatus back to Status - statusMap.map { case (path, serializableStatuses) => - val statuses = serializableStatuses.map { f => - val blockLocations = f.blockLocations.map { loc => - new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) - } - new LocatedFileStatus( - new FileStatus( - f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, - new Path(f.path)), - blockLocations) - } - (new Path(path), statuses) - } - } - - /** - * Lists a single filesystem path recursively. If a SparkSession object is specified, this - * function may launch Spark jobs to parallelize listing. - * - * If sessionOpt is None, this may be called on executors. - * - * @return all children of path that match the specified filter. - */ - private def listLeafFiles( - path: Path, - hadoopConf: Configuration, - filter: PathFilter, - sessionOpt: Option[SparkSession]): Seq[FileStatus] = { - logTrace(s"Listing $path") - val fs = path.getFileSystem(hadoopConf) - val name = path.getName.toLowerCase - - // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist - // Note that statuses only include FileStatus for the files and dirs directly under path, - // and does not include anything else recursively. 
- val statuses = try fs.listStatus(path) catch { - case _: FileNotFoundException => - logWarning(s"The directory $path was not found. Was it deleted very recently?") - Array.empty[FileStatus] - } - - val filteredStatuses = statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) - - val allLeafStatuses = { - val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory) - val nestedFiles: Seq[FileStatus] = sessionOpt match { - case Some(session) => - bulkListLeafFiles(dirs.map(_.getPath), hadoopConf, filter, session).flatMap(_._2) - case _ => - dirs.flatMap(dir => listLeafFiles(dir.getPath, hadoopConf, filter, sessionOpt)) - } - val allFiles = topLevelFiles ++ nestedFiles - if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles - } - - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { - case f: LocatedFileStatus => - f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. - // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not - // be a big deal since we always use to `listLeafFilesInParallel` when the number of - // paths exceeds threshold. - case f => - // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), - // which is very slow on some file system (RawLocalFileSystem, which is launch a - // subprocess and parse the stdout). - val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) - } - lfs - } - } - - /** Checks if we should filter out this path name. */ - def shouldFilterOut(pathName: String): Boolean = { - // We filter follow paths: - // 1. everything that starts with _ and ., except _common_metadata and _metadata - // because Parquet needs to find those metadata files from leaf files returned by this method. - // We should refactor this logic to not mix metadata files with data files. - // 2. everything that ends with `._COPYING_`, because this is a intermediate state of file. we - // should skip this file in case of double reading. 
- val exclude = (pathName.startsWith("_") && !pathName.contains("=")) || - pathName.startsWith(".") || pathName.endsWith("._COPYING_") - val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata") - exclude && !include - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 7ea406492757..00f5d5db8f5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -135,15 +135,15 @@ class FileIndexSuite extends SharedSQLContext { } } - test("PartitioningAwareFileIndex - file filtering") { - assert(!PartitioningAwareFileIndex.shouldFilterOut("abcd")) - assert(PartitioningAwareFileIndex.shouldFilterOut(".ab")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_cd")) - assert(!PartitioningAwareFileIndex.shouldFilterOut("_metadata")) - assert(!PartitioningAwareFileIndex.shouldFilterOut("_common_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_ab_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_cd_common_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("a._COPYING_")) + test("InMemoryFileIndex - file filtering") { + assert(!InMemoryFileIndex.shouldFilterOut("abcd")) + assert(InMemoryFileIndex.shouldFilterOut(".ab")) + assert(InMemoryFileIndex.shouldFilterOut("_cd")) + assert(!InMemoryFileIndex.shouldFilterOut("_metadata")) + assert(!InMemoryFileIndex.shouldFilterOut("_common_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("_ab_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("_cd_common_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("a._COPYING_")) } test("SPARK-17613 - PartitioningAwareFileIndex: base path w/o '/' at end") { From 7577e9c356b580d744e1fc27c645fce41bdf9cf0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Apr 2017 20:54:18 -0700 Subject: [PATCH 238/512] [SPARK-20246][SQL] should not push predicate down through aggregate with non-deterministic expressions ## What changes were proposed in this pull request? Similar to `Project`, when `Aggregate` has non-deterministic expressions, we should not push predicate down through it, as it will change the number of input rows and thus change the evaluation result of non-deterministic expressions in `Aggregate`. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17562 from cloud-fan/filter. --- .../sql/catalyst/optimizer/Optimizer.scala | 60 ++++++++++--------- .../optimizer/FilterPushdownSuite.scala | 41 +++++++++++-- 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 577112779eea..d221b0611a89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -755,7 +755,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // implies that, for a given input row, the output are determined by the expression's initial // state and all the input rows processed before. In another word, the order of input rows // matters for non-deterministic expressions, while pushing down predicates changes the order. 
- case filter @ Filter(condition, project @ Project(fields, grandChild)) + // This also applies to Aggregate. + case Filter(condition, project @ Project(fields, grandChild)) if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) => // Create a map of Aliases to their values from the child projection. @@ -766,33 +767,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) - // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be - // pushed beneath must satisfy the following conditions: - // 1. All the expressions are part of window partitioning key. The expressions can be compound. - // 2. Deterministic. - // 3. Placed before any non-deterministic predicates. - case filter @ Filter(condition, w: Window) - if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => - val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) - - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) - - val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(partitionAttrs) - } - - val stayUp = rest ++ containingNonDeterministic - - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) - if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) - } else { - filter - } - - case filter @ Filter(condition, aggregate: Aggregate) => + case filter @ Filter(condition, aggregate: Aggregate) + if aggregate.aggregateExpressions.forall(_.deterministic) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { @@ -823,6 +799,32 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be + // pushed beneath must satisfy the following conditions: + // 1. All the expressions are part of window partitioning key. The expressions can be compound. + // 2. Deterministic. + // 3. Placed before any non-deterministic predicates. 
+ case filter @ Filter(condition, w: Window) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) + + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) + if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) + } else { + filter + } + case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) @@ -848,7 +850,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - case filter @ Filter(condition, u: UnaryNode) + case filter @ Filter(_, u: UnaryNode) if canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index d846786473eb..ccd0b7c5d7f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -134,15 +134,20 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("nondeterministic: can't push down filter with nondeterministic condition through project") { + test("nondeterministic: can always push down filter through project with deterministic field") { val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('rand > 5 || 'a > 5) + .select('a) + .where(Rand(10) > 5 || 'a > 5) .analyze val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, originalQuery) + val correctAnswer = testRelation + .where(Rand(10) > 5 || 'a > 5) + .select('a) + .analyze + + comparePlans(optimized, correctAnswer) } test("nondeterministic: can't push down filter through project with nondeterministic field") { @@ -156,6 +161,34 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } + test("nondeterministic: can't push down filter through aggregate with nondeterministic field") { + val originalQuery = testRelation + .groupBy('a)('a, Rand(10).as('rand)) + .where('a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + + test("nondeterministic: push down part of filter through aggregate with deterministic field") { + val originalQuery = testRelation + .groupBy('a)('a) + .where('a > 5 && Rand(10) > 5) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .where('a > 5) + .groupBy('a)('a) + .where(Rand(10) > 5) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("filters: combines filters") { val originalQuery = testRelation .select('a) From e1afc4dcca8ba517f48200c0ecde1152505e41ec Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Apr 2017 21:14:50 -0700 Subject: [PATCH 
239/512] [SPARK-20262][SQL] AssertNotNull should throw NullPointerException ## What changes were proposed in this pull request? AssertNotNull currently throws RuntimeException. It should throw NullPointerException, which is more specific. ## How was this patch tested? N/A Author: Reynold Xin Closes #17573 from rxin/SPARK-20262. --- .../spark/sql/catalyst/expressions/objects/objects.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 00e2ac91e67c..53842ef348a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -989,7 +989,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. */ -case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) +case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType @@ -1005,7 +1005,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def eval(input: InternalRow): Any = { val result = child.eval(input) if (result == null) { - throw new RuntimeException(errMsg) + throw new NullPointerException(errMsg) } result } @@ -1021,7 +1021,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException($errMsgField); + throw new NullPointerException($errMsgField); } """ ev.copy(code = code, isNull = "false", value = childGen.value) From 34fc48fb5976ede00f3f6d8c4d3eec979e4f4d7f Mon Sep 17 00:00:00 2001 From: asmith26 Date: Sun, 9 Apr 2017 07:47:23 +0100 Subject: [PATCH 240/512] [MINOR] Issue: Change "slice" vs "partition" in exception messages (and code?) ## What changes were proposed in this pull request? Came across the term "slice" when running some spark scala code. Consequently, a Google search indicated that "slices" and "partitions" refer to the same things; indeed see: - [This issue](https://issues.apache.org/jira/browse/SPARK-1701) - [This pull request](https://github.com/apache/spark/pull/2305) - [This StackOverflow answer](http://stackoverflow.com/questions/23436640/what-is-the-difference-between-an-rdd-partition-and-a-slice) and [this one](http://stackoverflow.com/questions/24269495/what-are-the-differences-between-slices-and-partitions-of-rdds) Thus this pull request fixes the occurrence of slice I came accross. Nonetheless, [it would appear](https://github.com/apache/spark/search?utf8=%E2%9C%93&q=slice&type=) there are still many references to "slice/slices" - thus I thought I'd raise this Pull Request to address the issue (sorry if this is the wrong place, I'm not too familar with raising apache issues). ## How was this patch tested? (Not tested locally - only a minor exception message change.) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: asmith26 Closes #17565 from asmith26/master. 
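For readers who, like the author above, run into the term: in these APIs "slices" are simply partitions. A minimal sketch of the equivalence, assuming an existing `SparkContext` named `sc` (as in spark-shell):

```scala
// numSlices is just the number of partitions of the resulting RDD.
val rdd = sc.parallelize(1 to 100, numSlices = 4)
assert(rdd.getNumPartitions == 4)

// A non-positive slice/partition count is rejected when the partitions are
// computed; after this patch the message says "Positive number of partitions
// required" rather than mentioning "slices".
// sc.parallelize(1 to 100, 0).count()   // IllegalArgumentException
```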
--- .../main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala | 2 +- .../src/main/java/org/apache/spark/examples/JavaSparkPi.java | 2 +- examples/src/main/java/org/apache/spark/examples/JavaTC.java | 2 +- .../main/scala/org/apache/spark/examples/BroadcastTest.scala | 2 +- .../scala/org/apache/spark/examples/MultiBroadcastTest.scala | 2 +- .../src/main/scala/org/apache/spark/examples/SparkALS.scala | 2 +- examples/src/main/scala/org/apache/spark/examples/SparkLR.scala | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e9092739b298..9f8019b80a4d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -116,7 +116,7 @@ private object ParallelCollectionRDD { */ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) { - throw new IllegalArgumentException("Positive number of slices required") + throw new IllegalArgumentException("Positive number of partitions required") } // Sequences need to be sliced at the same set of index positions for operations // like RDD.zip() to behave as expected diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index cb4b26569088..37bd8fffbe45 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -26,7 +26,7 @@ /** * Computes an approximation to pi - * Usage: JavaSparkPi [slices] + * Usage: JavaSparkPi [partitions] */ public final class JavaSparkPi { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index bde30b84d6cf..c9ca9c9b3a41 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -32,7 +32,7 @@ /** * Transitive closure on a graph, implemented in Java. 
- * Usage: JavaTC [slices] + * Usage: JavaTC [partitions] */ public final class JavaTC { diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 86eed3867c53..25718f904cc4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -21,7 +21,7 @@ package org.apache.spark.examples import org.apache.spark.sql.SparkSession /** - * Usage: BroadcastTest [slices] [numElem] [blockSize] + * Usage: BroadcastTest [partitions] [numElem] [blockSize] */ object BroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 6495a86fcd77..e6f33b7adf5d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession /** - * Usage: MultiBroadcastTest [slices] [numElem] + * Usage: MultiBroadcastTest [partitions] [numElem] */ object MultiBroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 8a3d08f45978..a99ddd9fd37d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -100,7 +100,7 @@ object SparkALS { ITERATIONS = iters.getOrElse("5").toInt slices = slices_.getOrElse("2").toInt case _ => - System.err.println("Usage: SparkALS [M] [U] [F] [iters] [slices]") + System.err.println("Usage: SparkALS [M] [U] [F] [iters] [partitions]") System.exit(1) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index afa8f58c96e5..cb2be091ffcf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * Logistic regression based classification. - * Usage: SparkLR [slices] + * Usage: SparkLR [partitions] * * This is an example implementation for learning how to use Spark. For more conventional use, * please refer to org.apache.spark.ml.classification.LogisticRegression. From 1f0de3c1c85a41eadc7c4131bdc948405f340099 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 9 Apr 2017 08:44:02 +0100 Subject: [PATCH 241/512] [SPARK-19991][CORE][YARN] FileSegmentManagedBuffer performance improvement ## What changes were proposed in this pull request? Avoid `NoSuchElementException` every time `ConfigProvider.get(val, default)` falls back to default. This apparently causes non-trivial overhead in at least one path, and can easily be avoided. See https://github.com/apache/spark/pull/17329 ## How was this patch tested? Existing tests Author: Sean Owen Closes #17567 from srowen/SPARK-19991. 
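The overhead being avoided here is exception-driven control flow: the two-argument `get` presumably falls back to the default by catching `NoSuchElementException` from the one-argument lookup, so every miss pays for constructing and unwinding an exception. A rough Scala sketch of the before/after pattern; the class and bodies below are illustrative stand-ins for the Java classes touched by the diff, not the actual library code:

```scala
import java.util.NoSuchElementException

// Stand-in for the abstract provider: the default two-argument get falls back
// via an exception, which is the costly path this patch avoids.
abstract class ConfigProvider {
  def get(name: String): String
  def get(name: String, defaultValue: String): String =
    try get(name) catch { case _: NoSuchElementException => defaultValue }
}

// The concrete providers now override the two-argument form so that a miss is
// just a null check against the backing map, with no exception involved.
class MapBackedProvider(config: java.util.Map[String, String]) extends ConfigProvider {
  override def get(name: String): String = {
    val value = config.get(name)
    if (value == null) throw new NoSuchElementException(name) else value
  }
  override def get(name: String, defaultValue: String): String = {
    val value = config.get(name)
    if (value == null) defaultValue else value
  }
}
```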
--- .../org/apache/spark/network/util/MapConfigProvider.java | 6 ++++++ .../spark/network/yarn/util/HadoopConfigProvider.java | 6 ++++++ .../org/apache/spark/network/netty/SparkTransportConf.scala | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java index 9cfee7f08d15..a2cf87d1af7e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -42,6 +42,12 @@ public String get(String name) { return value; } + @Override + public String get(String name, String defaultValue) { + String value = config.get(name); + return value == null ? defaultValue : value; + } + @Override public Iterable> getAll() { return config.entrySet(); diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java index 62a6cca4ed4e..8beb03369947 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -41,6 +41,12 @@ public String get(String name) { return value; } + @Override + public String get(String name, String defaultValue) { + String value = conf.get(name); + return value == null ? defaultValue : value; + } + @Override public Iterable> getAll() { return conf; diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index df520f804b4c..25f7bcb9801b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -60,7 +60,7 @@ object SparkTransportConf { new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) - + override def get(name: String, defaultValue: String): String = conf.get(name, defaultValue) override def getAll(): java.lang.Iterable[java.util.Map.Entry[String, String]] = { conf.getAll.toMap.asJava.entrySet() } From 261eaf5149a8fe479ab4f9c34db892bcedbf5739 Mon Sep 17 00:00:00 2001 From: Vijay Ramesh Date: Sun, 9 Apr 2017 19:39:09 +0100 Subject: [PATCH 242/512] [SPARK-20260][MLLIB] String interpolation required for error message ## What changes were proposed in this pull request? This error message doesn't get properly formatted because of a missing `s`. Currently the error looks like: ``` Caused by: java.lang.IllegalArgumentException: requirement failed: indices should be one-based and in ascending order; found current=$current, previous=$previous; line="$line" ``` (note the literal `$current` instead of the interpolated value) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Vijay Ramesh Closes #17572 from vijaykramesh/master. 
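For context, the bug class is easy to reproduce: without the `s` prefix, a Scala string literal keeps `$name` placeholders as literal text instead of interpolating them. A minimal sketch:

```scala
val current = 7
val previous = 3

// Missing interpolator: the placeholders stay as literal text.
val bad = "found current=$current, previous=$previous"
// bad == "found current=$current, previous=$previous"

// With the s prefix the values are spliced in as intended.
val good = s"found current=$current, previous=$previous"
// good == "found current=7, previous=3"
```

Where available, scalac's `-Xlint:missing-interpolator` warning flags most occurrences of this pattern at compile time.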
--- .../scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 2 +- .../test/scala/org/apache/spark/ml/util/TestingUtils.scala | 2 +- .../spark/mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../apache/spark/mllib/tree/model/DecisionTreeModel.scala | 2 +- .../main/scala/org/apache/spark/mllib/util/MLUtils.scala | 2 +- .../scala/org/apache/spark/mllib/util/TestingUtils.scala | 2 +- .../scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala | 6 +++--- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index f475ce87540a..bae7a3f307f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -349,7 +349,7 @@ class SparkHadoopUtil extends Logging { } } catch { case e: IOException => - logDebug("Failed to decode $token: $e", e) + logDebug(s"Failed to decode $token: $e", e) } buffer.toString } diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala index 30edd00fb53e..6c79d77f142e 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -215,7 +215,7 @@ object TestingUtils { if (r.fun(x, r.y, r.eps)) { throw new TestFailedException( s"Did not expect \n$x\n and \n${r.y}\n to be within " + - "${r.eps}${r.method} for all elements.", 0) + s"${r.eps}${r.method} for all elements.", 0) } true } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 4d3e265455da..b2437b845f82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -259,7 +259,7 @@ object PowerIterationClustering extends Logging { val j = ctx.dstId val s = ctx.attr if (s < 0.0) { - throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.") } if (s > 0.0) { ctx.sendToSrc(s) @@ -283,7 +283,7 @@ object PowerIterationClustering extends Logging { : Graph[Double, Double] = { val edges = similarities.flatMap { case (i, j, s) => if (s < 0.0) { - throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.") } if (i != j) { Seq(Edge(i, j, s), Edge(j, i, s)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a1562384b0a7..27618e122aef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -248,7 +248,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { // Build node data into a tree. 
val trees = constructTrees(nodes) assert(trees.length == 1, - "Decision tree should contain exactly one tree but got ${trees.size} trees.") + s"Decision tree should contain exactly one tree but got ${trees.size} trees.") val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $dataPath." + s" Expected $numNodes nodes but found ${model.numNodes}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 95f904dac552..4fdad0597396 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -119,7 +119,7 @@ object MLUtils extends Logging { while (i < indicesLength) { val current = indices(i) require(current > previous, s"indices should be one-based and in ascending order;" - + " found current=$current, previous=$previous; line=\"$line\"") + + s""" found current=$current, previous=$previous; line="$line"""") previous = current i += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 39a6bc37d963..d39865a19a5c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -207,7 +207,7 @@ object TestingUtils { if (r.fun(x, r.y, r.eps)) { throw new TestFailedException( s"Did not expect \n$x\n and \n${r.y}\n to be within " + - "${r.eps}${r.method} for all elements.", 0) + s"${r.eps}${r.method} for all elements.", 0) } true } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 5d8ba9d7c85d..8c855730c31f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -285,7 +285,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -294,7 +294,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) withOrcTable(data, "t") { checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -303,7 +303,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) withOrcTable(data, "t") { checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) From 7a63f5e82758345ff1f3322950f2bbea350c48b9 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 10 Apr 2017 10:47:17 +0800 Subject: [PATCH 243/512] [SPARK-20253][SQL] Remove unnecessary nullchecks of a return value from Spark runtime routines in generated Java code 
## What changes were proposed in this pull request? This PR elminates unnecessary nullchecks of a return value from known Spark runtime routines. We know whether a given Spark runtime routine returns ``null`` or not (e.g. ``ArrayData.toDoubleArray()`` never returns ``null``). Thus, we can eliminate a null check for the return value from the Spark runtime routine. When we run the following example program, now we get the Java code "Without this PR". In this code, since we know ``ArrayData.toDoubleArray()`` never returns ``null```, we can eliminate null checks at lines 90-92, and 97. ```java val ds = sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache ds.count ds.map(e => e).show ``` Without this PR ```java /* 050 */ protected void processNext() throws java.io.IOException { /* 051 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 052 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 053 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 054 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 055 */ /* 056 */ ArrayData deserializetoobject_value1 = null; /* 057 */ /* 058 */ if (!inputadapter_isNull) { /* 059 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 060 */ /* 061 */ Double[] deserializetoobject_convertedArray = null; /* 062 */ deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength]; /* 063 */ /* 064 */ int deserializetoobject_loopIndex = 0; /* 065 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 066 */ MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex)); /* 067 */ MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 068 */ /* 069 */ if (MapObjects_loopIsNull2) { /* 070 */ throw new RuntimeException(((java.lang.String) references[0])); /* 071 */ } /* 072 */ if (false) { /* 073 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 074 */ } else { /* 075 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2; /* 076 */ } /* 077 */ /* 078 */ deserializetoobject_loopIndex += 1; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/ /* 082 */ } /* 083 */ boolean deserializetoobject_isNull = true; /* 084 */ double[] deserializetoobject_value = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull = false; /* 087 */ if (!deserializetoobject_isNull) { /* 088 */ Object deserializetoobject_funcResult = null; /* 089 */ deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray(); /* 090 */ if (deserializetoobject_funcResult == null) { /* 091 */ deserializetoobject_isNull = true; /* 092 */ } else { /* 093 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 098 */ } /* 099 */ /* 100 */ boolean mapelements_isNull = true; /* 101 */ double[] mapelements_value = null; /* 102 */ if (!false) { /* 103 */ mapelements_resultIsNull = false; /* 104 */ /* 105 */ if (!mapelements_resultIsNull) { /* 106 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 107 */ mapelements_argValue = deserializetoobject_value; /* 108 */ } /* 109 */ /* 110 */ mapelements_isNull = mapelements_resultIsNull; /* 111 */ if 
(!mapelements_isNull) { /* 112 */ Object mapelements_funcResult = null; /* 113 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 114 */ if (mapelements_funcResult == null) { /* 115 */ mapelements_isNull = true; /* 116 */ } else { /* 117 */ mapelements_value = (double[]) mapelements_funcResult; /* 118 */ } /* 119 */ /* 120 */ } /* 121 */ mapelements_isNull = mapelements_value == null; /* 122 */ } /* 123 */ /* 124 */ serializefromobject_resultIsNull = false; /* 125 */ /* 126 */ if (!serializefromobject_resultIsNull) { /* 127 */ serializefromobject_resultIsNull = mapelements_isNull; /* 128 */ serializefromobject_argValue = mapelements_value; /* 129 */ } /* 130 */ /* 131 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 132 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 133 */ serializefromobject_isNull = serializefromobject_value == null; /* 134 */ serializefromobject_holder.reset(); /* 135 */ /* 136 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 137 */ /* 138 */ if (serializefromobject_isNull) { /* 139 */ serializefromobject_rowWriter.setNullAt(0); /* 140 */ } else { /* 141 */ // Remember the current cursor so that we can calculate how many bytes are /* 142 */ // written later. /* 143 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 144 */ /* 145 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 146 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 147 */ // grow the global buffer before writing data. /* 148 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 149 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 150 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 151 */ /* 152 */ } else { /* 153 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 154 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 155 */ /* 156 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 157 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 158 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 159 */ } else { /* 160 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 161 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 162 */ } /* 163 */ } /* 164 */ } /* 165 */ /* 166 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 167 */ } /* 168 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 169 */ append(serializefromobject_result); /* 170 */ if (shouldStop()) return; /* 171 */ } /* 172 */ } ``` With this PR (removed most of lines 90-97 in the above code) ```java /* 050 */ protected void processNext() throws java.io.IOException { /* 051 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 052 */ InternalRow inputadapter_row = (InternalRow) 
inputadapter_input.next(); /* 053 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 054 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 055 */ /* 056 */ ArrayData deserializetoobject_value1 = null; /* 057 */ /* 058 */ if (!inputadapter_isNull) { /* 059 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 060 */ /* 061 */ Double[] deserializetoobject_convertedArray = null; /* 062 */ deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength]; /* 063 */ /* 064 */ int deserializetoobject_loopIndex = 0; /* 065 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 066 */ MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex)); /* 067 */ MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 068 */ /* 069 */ if (MapObjects_loopIsNull2) { /* 070 */ throw new RuntimeException(((java.lang.String) references[0])); /* 071 */ } /* 072 */ if (false) { /* 073 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 074 */ } else { /* 075 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2; /* 076 */ } /* 077 */ /* 078 */ deserializetoobject_loopIndex += 1; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/ /* 082 */ } /* 083 */ boolean deserializetoobject_isNull = true; /* 084 */ double[] deserializetoobject_value = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull = false; /* 087 */ if (!deserializetoobject_isNull) { /* 088 */ Object deserializetoobject_funcResult = null; /* 089 */ deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray(); /* 090 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 091 */ /* 092 */ } /* 093 */ /* 094 */ } /* 095 */ /* 096 */ boolean mapelements_isNull = true; /* 097 */ double[] mapelements_value = null; /* 098 */ if (!false) { /* 099 */ mapelements_resultIsNull = false; /* 100 */ /* 101 */ if (!mapelements_resultIsNull) { /* 102 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 103 */ mapelements_argValue = deserializetoobject_value; /* 104 */ } /* 105 */ /* 106 */ mapelements_isNull = mapelements_resultIsNull; /* 107 */ if (!mapelements_isNull) { /* 108 */ Object mapelements_funcResult = null; /* 109 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 110 */ if (mapelements_funcResult == null) { /* 111 */ mapelements_isNull = true; /* 112 */ } else { /* 113 */ mapelements_value = (double[]) mapelements_funcResult; /* 114 */ } /* 115 */ /* 116 */ } /* 117 */ mapelements_isNull = mapelements_value == null; /* 118 */ } /* 119 */ /* 120 */ serializefromobject_resultIsNull = false; /* 121 */ /* 122 */ if (!serializefromobject_resultIsNull) { /* 123 */ serializefromobject_resultIsNull = mapelements_isNull; /* 124 */ serializefromobject_argValue = mapelements_value; /* 125 */ } /* 126 */ /* 127 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 128 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? 
null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 129 */ serializefromobject_isNull = serializefromobject_value == null; /* 130 */ serializefromobject_holder.reset(); /* 131 */ /* 132 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 133 */ /* 134 */ if (serializefromobject_isNull) { /* 135 */ serializefromobject_rowWriter.setNullAt(0); /* 136 */ } else { /* 137 */ // Remember the current cursor so that we can calculate how many bytes are /* 138 */ // written later. /* 139 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 140 */ /* 141 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 142 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 143 */ // grow the global buffer before writing data. /* 144 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 145 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 146 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 147 */ /* 148 */ } else { /* 149 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 150 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 151 */ /* 152 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 153 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 154 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 155 */ } else { /* 156 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 157 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 158 */ } /* 159 */ } /* 160 */ } /* 161 */ /* 162 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 163 */ } /* 164 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 165 */ append(serializefromobject_result); /* 166 */ if (shouldStop()) return; /* 167 */ } /* 168 */ } ``` ## How was this patch tested? Add test suites to ``DatasetPrimitiveSuite`` Author: Kazuaki Ishizaki Closes #17569 from kiszk/SPARK-20253. 
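One way to observe the effect described above is to dump the whole-stage generated code for a primitive-array map. A minimal sketch, assuming a running `SparkSession` named `spark` and the developer-facing `debugCodegen()` helper from the SQL execution debug package (if that helper is not available in your build, `explain()` on the mapped Dataset is a rougher substitute):

```scala
import org.apache.spark.sql.execution.debug._   // assumed to provide debugCodegen()

import spark.implicits._

// The motivating primitive-array example from the description above.
val ds = spark.sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS().cache()
ds.count()

// Dump the generated Java for the mapped Dataset; with this patch the
// `if (funcResult == null)` branches around toDoubleArray() should be gone.
ds.map(e => e).debugCodegen()
```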
--- .../spark/sql/catalyst/ScalaReflection.scala | 27 ++++++++++-------- .../sql/catalyst/encoders/RowEncoder.scala | 19 +++++++------ .../expressions/objects/objects.scala | 28 +++++++++---------- .../spark/sql/DatasetPrimitiveSuite.scala | 10 +++++++ 4 files changed, 51 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 206ae2f0e5eb..198122759e4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection { getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) + Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) + Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) + Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + returnNullable = false) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection { val arrayCls = arrayClassFor(elementType) if (elementNullable) { - Invoke(arrayData, "array", arrayCls) + Invoke(arrayData, "array", arrayCls, returnNullable = false) } else { val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => "toIntArray" @@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection { case other => throw new IllegalStateException("expect primitive array element type " + "but got " + other) } - Invoke(arrayData, primitiveMethod, arrayCls) + Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } case t if t <:< localTypeOf[Seq[_]] => @@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection { Invoke( MapObjects( p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), + returnNullable = false), schemaFor(keyType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) val valueData = Invoke( MapObjects( p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), + returnNullable = false), schemaFor(valueType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) StaticInvoke( ArrayBasedMapData.getClass, diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e95e97b9dc6c..0f8282d3b2f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -89,7 +89,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) case TimestampType => StaticInvoke( @@ -136,16 +136,18 @@ object RowEncoder { case t @ MapType(kt, vt, valueNullable) => val keys = Invoke( - Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( - Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) NewInstance( @@ -262,17 +264,18 @@ object RowEncoder { input :: Nil) case _: DecimalType => - Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case StringType => - Invoke(input, "toString", ObjectType(classOf[String])) + Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) case ArrayType(et, nullable) => val arrayData = Invoke( MapObjects(deserializerFor(_), input, et), "array", - ObjectType(classOf[Array[_]])) + ObjectType(classOf[Array[_]]), returnNullable = false) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 53842ef348a5..6d94764f1bfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -225,25 +225,26 @@ case class Invoke( getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { val funcResult = ctx.freshName("funcResult") + // If the function can return null, we do an extra check to make sure our null bit is still + // set correctly. 
+ val assignResult = if (!returnNullable) { + s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + } else { + s""" + if ($funcResult != null) { + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + } else { + ${ev.isNull} = true; + } + """ + } s""" Object $funcResult = null; ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - if ($funcResult == null) { - ${ev.isNull} = true; - } else { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; - } + $assignResult """ } - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val postNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } - val code = s""" ${obj.code} boolean ${ev.isNull} = true; @@ -254,7 +255,6 @@ case class Invoke( if (!${ev.isNull}) { $evaluate } - $postNullCheck } """ ev.copy(code = code) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 82b707537e45..541565344f75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -96,6 +96,16 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(dsBoolean.map(e => !e), false, true) } + test("mapPrimitiveArray") { + val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS() + checkDataset(dsInt.map(e => e), Array(1, 2), Array(3, 4)) + checkDataset(dsInt.map(e => null: Array[Int]), null, null) + + val dsDouble = Seq(Array(1D, 2D), Array(3D, 4D)).toDS() + checkDataset(dsDouble.map(e => e), Array(1D, 2D), Array(3D, 4D)) + checkDataset(dsDouble.map(e => null: Array[Double]), null, null) + } + test("filter") { val ds = Seq(1, 2, 3, 4).toDS() checkDataset( From 7bfa05e0a5e6860a942e1ce47e7890d665acdfe3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 9 Apr 2017 20:32:07 -0700 Subject: [PATCH 244/512] [SPARK-20264][SQL] asm should be non-test dependency in sql/core ## What changes were proposed in this pull request? sq/core module currently declares asm as a test scope dependency. Transitively it should actually be a normal dependency since the actual core module defines it. This occasionally confuses IntelliJ. ## How was this patch tested? N/A - This is a build change. Author: Reynold Xin Closes #17574 from rxin/SPARK-20264. --- sql/core/pom.xml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 69d797b47915..b203f31a76f0 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.xbean + xbean-asm5-shaded + org.scalacheck scalacheck_${scala.binary.version} @@ -147,11 +151,6 @@ mockito-core test - - org.apache.xbean - xbean-asm5-shaded - test - target/scala-${scala.binary.version}/classes From 1a0bc41659eef317dcac18df35c26857216a4314 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 10 Apr 2017 05:16:34 +0000 Subject: [PATCH 245/512] [SPARK-20270][SQL] na.fill should not change the values in long or integer when the default value is in double ## What changes were proposed in this pull request? This bug was partially addressed in SPARK-18555 https://github.com/apache/spark/pull/15994, but the root cause isn't completely solved. 
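The root cause is the usual Long-through-Double round trip described below; a quick illustration, using the id value from the example in this description (illustration only, not part of the patch):

```scala
// A Double has a 53-bit significand, so Long values beyond 2^53 generally do
// not survive a round trip through Double.
val id = 9123146099426677101L
val roundTripped = id.toDouble.toLong
assert(roundTripped != id)   // the low-order bits are rounded away

// That round trip is what the pre-fix plan performed via
// cast(coalesce(cast(a as double), cast(0.2 as double)) as bigint); the fix
// casts the fill value to the column type instead, coalesce(a, cast(0.2 as
// bigint)), so non-null Longs are left untouched.
```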
This bug is pretty critical since it changes the member id in Long in our application if the member id can not be represented by Double losslessly when the member id is very big. Here is an example how this happens, with ``` Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), ``` the logical plan will be ``` == Analyzed Logical Plan == a: bigint, b: double Project [cast(coalesce(cast(a#232L as double), cast(0.2 as double)) as bigint) AS a#240L, cast(coalesce(nanvl(b#233, cast(null as double)), 0.2) as double) AS b#241] +- Project [_1#229L AS a#232L, _2#230 AS b#233] +- LocalRelation [_1#229L, _2#230] ``` Note that even the value is not null, Spark will cast the Long into Double first. Then if it's not null, Spark will cast it back to Long which results in losing precision. The behavior should be that the original value should not be changed if it's not null, but Spark will change the value which is wrong. With the PR, the logical plan will be ``` == Analyzed Logical Plan == a: bigint, b: double Project [coalesce(a#232L, cast(0.2 as bigint)) AS a#240L, coalesce(nanvl(b#233, cast(null as double)), cast(0.2 as double)) AS b#241] +- Project [_1#229L AS a#232L, _2#230 AS b#233] +- LocalRelation [_1#229L, _2#230] ``` which behaves correctly without changing the original Long values and also avoids extra cost of unnecessary casting. ## How was this patch tested? unit test added. +cc srowen rxin cloud-fan gatorsmile Thanks. Author: DB Tsai Closes #17577 from dbtsai/fixnafill. --- .../apache/spark/sql/DataFrameNaFunctions.scala | 5 +++-- .../spark/sql/DataFrameNaFunctionsSuite.scala | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 28820681cd3a..d8f953fba5a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -407,10 +407,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types + // nanvl only supports these types + nanvl(df.col(quotedColName), lit(null).cast(col.dataType)) case _ => df.col(quotedColName) } - coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) + coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index fd829846ac33..aa237d0619ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -145,6 +145,20 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil ) + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), + (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6) + :: Row(0, 0.2) :: Nil + ) + + checkAnswer( + 
Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null), + (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f) + :: Row(0, 0.2f) :: Nil + ) + checkAnswer( Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) .toDF("a", "b").na.fill(2.34), From 3d7f201f2adc2d33be6f564fa76435c18552f4ba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 10 Apr 2017 13:36:08 +0800 Subject: [PATCH 246/512] [SPARK-20229][SQL] add semanticHash to QueryPlan ## What changes were proposed in this pull request? Like `Expression`, `QueryPlan` should also have a `semanticHash` method, then we can put plans to a hash map and look it up fast. This PR refactors `QueryPlan` to follow `Expression` and put all the normalization logic in `QueryPlan.canonicalized`, so that it's very natural to implement `semanticHash`. follow-up: improve `CacheManager` to leverage this `semanticHash` and speed up plan lookup, instead of iterating all cached plans. ## How was this patch tested? existing tests. Note that we don't need to test the `semanticHash` method, once the existing tests prove `sameResult` is correct, we are good. Author: Wenchen Fan Closes #17541 from cloud-fan/plan-semantic. --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/catalog/interface.scala | 11 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 102 +++++++++++------- .../plans/logical/LocalRelation.scala | 8 -- .../catalyst/plans/logical/LogicalPlan.scala | 2 - .../plans/logical/basicLogicalOperators.scala | 2 + .../plans/physical/broadcastMode.scala | 9 +- .../sql/execution/DataSourceScanExec.scala | 37 +++---- .../spark/sql/execution/ExistingRDD.scala | 14 --- .../sql/execution/LocalTableScanExec.scala | 2 +- .../execution/basicPhysicalOperators.scala | 10 +- .../datasources/LogicalRelation.scala | 13 +-- .../exchange/BroadcastExchangeExec.scala | 6 +- .../sql/execution/exchange/Exchange.scala | 6 +- .../sql/execution/joins/HashedRelation.scala | 11 +- .../spark/sql/execution/ExchangeSuite.scala | 18 ++-- .../hive/execution/HiveTableScanExec.scala | 45 ++++---- 17 files changed, 135 insertions(+), 163 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c698ca6a8347..b0cdef70297c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -617,7 +617,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => - lookupTableFromCatalog(u).canonicalized match { + EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => u.failAnalysis(s"Inserting into a view is not allowed. 
View: ${v.desc.identifier}.") case other => i.copy(table = other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 360e55d92282..cc0cbba275b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -423,8 +423,15 @@ case class CatalogRelation( Objects.hashCode(tableMeta.identifier, output) } - /** Only compare table identifier. */ - override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier) + override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable( + identifier = tableMeta.identifier, + tableType = tableMeta.tableType, + storage = CatalogStorageFormat.empty, + schema = tableMeta.schema, + partitionColumnNames = tableMeta.partitionColumnNames, + bucketSpec = tableMeta.bucketSpec, + createTime = -1 + )) override def computeStats(conf: SQLConf): Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2d8ec2053a4c..3008e8cb8465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -359,9 +359,59 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** - * Canonicalized copy of this query plan. + * Returns a plan where a best effort attempt has been made to transform `this` in a way + * that preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, expression id, etc.) + * + * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same + * result. + * + * Some nodes should overwrite this to provide proper canonicalize logic. + */ + lazy val canonicalized: PlanType = { + val canonicalizedChildren = children.map(_.canonicalized) + var id = -1 + preCanonicalized.mapExpressions { + case a: Alias => + id += 1 + // As the root of the expression, Alias will always take an arbitrary exprId, we need to + // normalize that for equality testing, by assigning expr id from 0 incrementally. The + // alias name doesn't matter and should be erased. + Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + + case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => + // Top level `AttributeReference` may also be used for output like `Alias`, we should + // normalize the epxrId too. + id += 1 + ar.withExprId(ExprId(id)) + + case other => normalizeExprId(other) + }.withNewChildren(canonicalizedChildren) + } + + /** + * Do some simple transformation on this plan before canonicalizing. Implementations can override + * this method to provide customized canonicalize logic without rewriting the whole logic. */ - protected lazy val canonicalized: PlanType = this + protected def preCanonicalized: PlanType = this + + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. 
It's similar to `BindReferences` but we + * do not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = { + e.transformUp { + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized.asInstanceOf[T] + } /** * Returns true when the given query plan will return the same results as this query plan. @@ -372,49 +422,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * enhancements like caching. However, it is not acceptable to return true if the results could * possibly be different. * - * By default this function performs a modified version of equality that is tolerant of cosmetic - * differences like attribute naming and or expression id differences. Operators that - * can do better should override this function. + * This function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and or expression id differences. */ - def sameResult(plan: PlanType): Boolean = { - val left = this.canonicalized - val right = plan.canonicalized - left.getClass == right.getClass && - left.children.size == right.children.size && - left.cleanArgs == right.cleanArgs && - (left.children, right.children).zipped.forall(_ sameResult _) - } + final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized + + /** + * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. + */ + final def semanticHash(): Int = canonicalized.hashCode() /** * All the attributes that are used for this plan. */ lazy val allAttributes: AttributeSeq = children.flatMap(_.output) - - protected def cleanExpression(e: Expression): Expression = e match { - case a: Alias => - // As the root of the expression, Alias will always take an arbitrary exprId, we need - // to erase that for equality testing. - val cleanedExprId = - Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated) - BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true) - case other => - BindReferences.bindReference(other, allAttributes, allowFailures = true) - } - - /** Args that have cleaned such that differences in expression id should not affect equality */ - protected lazy val cleanArgs: Seq[Any] = { - def cleanArg(arg: Any): Any = arg match { - // Children are checked using sameResult above. 
- case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e).canonicalized - case other => other - } - - mapProductIterator { - case s: Option[_] => s.map(cleanArg) - case s: Seq[_] => s.map(cleanArg) - case m: Map[_, _] => m.mapValues(cleanArg) - case other => cleanArg(other) - }.toSeq - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b7177c4a2c4e..9cd5dfd21b16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -67,14 +67,6 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case LocalRelation(otherOutput, otherData) => - otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data - case _ => false - } - } - override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 036b6256684c..6bdcf490ca5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -143,8 +143,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = children.forall(_.resolved) - override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) - /** * Resolves a given schema to concrete [[Attribute]] references in this query plan. This function * should only be called on analyzed plans since it will throw [[AnalysisException]] for diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c91de08ca5ef..3ad757ebba85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -803,6 +803,8 @@ case class SubqueryAlias( child: LogicalPlan) extends UnaryNode { + override lazy val canonicalized: LogicalPlan = child.canonicalized + override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 9dfdf4da78ff..2ab46dc8330a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -26,10 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow trait BroadcastMode { def transform(rows: Array[InternalRow]): Any - /** - * Returns true iff this [[BroadcastMode]] generates the same result as `other`. 
- */ - def compatibleWith(other: BroadcastMode): Boolean + def canonicalized: BroadcastMode } /** @@ -39,7 +36,5 @@ case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows - override def compatibleWith(other: BroadcastMode): Boolean = { - this eq other - } + override def canonicalized: BroadcastMode = this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 2fa660c4d5e0..3a9132d74ac1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -119,7 +119,7 @@ case class RowDataSourceScanExec( val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => - new BoundReference(i, a.dataType, a.nullable) + BoundReference(i, a.dataType, a.nullable) } val row = ctx.freshName("row") ctx.INPUT_ROW = row @@ -136,19 +136,17 @@ case class RowDataSourceScanExec( """.stripMargin } - // Ignore rdd when checking results - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: RowDataSourceScanExec => relation == other.relation && metadata == other.metadata - case _ => false - } + // Only care about `relation` and `metadata` when canonicalizing. + override def preCanonicalized: SparkPlan = + copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None) } /** * Physical plan node for scanning data from HadoopFsRelations. * * @param relation The file-based relation to scan. - * @param output Output attributes of the scan. - * @param outputSchema Output schema of the scan. + * @param output Output attributes of the scan, including data attributes and partition attributes. + * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. * @param dataFilters Filters on non-partition columns. * @param metastoreTableIdentifier identifier for the table in the metastore. 
@@ -156,7 +154,7 @@ case class RowDataSourceScanExec( case class FileSourceScanExec( @transient relation: HadoopFsRelation, output: Seq[Attribute], - outputSchema: StructType, + requiredSchema: StructType, partitionFilters: Seq[Expression], dataFilters: Seq[Expression], override val metastoreTableIdentifier: Option[TableIdentifier]) @@ -267,7 +265,7 @@ case class FileSourceScanExec( val metadata = Map( "Format" -> relation.fileFormat.toString, - "ReadSchema" -> outputSchema.catalogString, + "ReadSchema" -> requiredSchema.catalogString, "Batched" -> supportsBatch.toString, "PartitionFilters" -> seqToString(partitionFilters), "PushedFilters" -> seqToString(pushedDownFilters), @@ -287,7 +285,7 @@ case class FileSourceScanExec( sparkSession = relation.sparkSession, dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, - requiredSchema = outputSchema, + requiredSchema = requiredSchema, filters = pushedDownFilters, options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) @@ -515,14 +513,13 @@ case class FileSourceScanExec( } } - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: FileSourceScanExec => - val thisPredicates = partitionFilters.map(cleanExpression) - val otherPredicates = other.partitionFilters.map(cleanExpression) - val result = relation == other.relation && metadata == other.metadata && - thisPredicates.length == otherPredicates.length && - thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) - result - case _ => false + override lazy val canonicalized: FileSourceScanExec = { + FileSourceScanExec( + relation, + output.map(normalizeExprId(_, output)), + requiredSchema, + partitionFilters.map(normalizeExprId(_, output)), + dataFilters.map(normalizeExprId(_, output)), + None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2827b8ac0033..3d1b481a53e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -87,13 +87,6 @@ case class ExternalRDD[T]( override def newInstance(): ExternalRDD.this.type = ExternalRDD(outputObjAttr.newInstance(), rdd)(session).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case ExternalRDD(_, otherRDD) => rdd.id == otherRDD.id - case _ => false - } - } - override protected def stringArgs: Iterator[Any] = Iterator(output) @transient override def computeStats(conf: SQLConf): Statistics = Statistics( @@ -162,13 +155,6 @@ case class LogicalRDD( )(session).asInstanceOf[this.type] } - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id - case _ => false - } - } - override protected def stringArgs: Iterator[Any] = Iterator(output) @transient override def computeStats(conf: SQLConf): Statistics = Statistics( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index e366b9af35c6..19c68c13262a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -33,7 +33,7 @@ case class LocalTableScanExec( override 
lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - private val unsafeRows: Array[InternalRow] = { + private lazy val unsafeRows: Array[InternalRow] = { if (rows.isEmpty) { Array.empty } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 66a8e044ab87..44278e37c527 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -342,8 +342,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows")) - // output attributes should not affect the results - override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements) + override lazy val canonicalized: SparkPlan = { + RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range]) + } override def inputRDDs(): Seq[RDD[InternalRow]] = { sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) @@ -607,11 +608,6 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def sameResult(o: SparkPlan): Boolean = o match { - case s: SubqueryExec => child.sameResult(s.child) - case _ => false - } - @transient private lazy val relationFuture: Future[Array[InternalRow]] = { // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 421520396007..3813f953e06a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -43,17 +43,8 @@ case class LogicalRelation( com.google.common.base.Objects.hashCode(relation, output) } - override def sameResult(otherPlan: LogicalPlan): Boolean = { - otherPlan.canonicalized match { - case LogicalRelation(otherRelation, _, _) => relation == otherRelation - case _ => false - } - } - - // When comparing two LogicalRelations from within LogicalPlan.sameResult, we only need - // LogicalRelation.cleanArgs to return Seq(relation), since expectedOutputAttribute's - // expId can be different but the relation is still the same. - override lazy val cleanArgs: Seq[Any] = Seq(relation) + // Only care about relation when canonicalizing. 
+ override def preCanonicalized: LogicalPlan = copy(catalogTable = None) @transient override def computeStats(conf: SQLConf): Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index efcaca9338ad..9c859e41f876 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -48,10 +48,8 @@ case class BroadcastExchangeExec( override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) - override def sameResult(plan: SparkPlan): Boolean = plan match { - case p: BroadcastExchangeExec => - mode.compatibleWith(p.mode) && child.sameResult(p.child) - case _ => false + override lazy val canonicalized: SparkPlan = { + BroadcastExchangeExec(mode.canonicalized, child.canonicalized) } @transient diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 9a9597d3733e..d993ea6c6cef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -48,10 +48,8 @@ abstract class Exchange extends UnaryExecNode { case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange) extends LeafExecNode { - override def sameResult(plan: SparkPlan): Boolean = { - // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here. - plan.sameResult(child) - } + // Ignore this wrapper for canonicalizing. 
+ override lazy val canonicalized: SparkPlan = child.canonicalized def doExecute(): RDD[InternalRow] = { child.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b9f6601ea87f..2dd1dc3da96c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -829,15 +829,10 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalizedKey, rows.length) + HashedRelation(rows.iterator, canonicalized.key, rows.length) } - private lazy val canonicalizedKey: Seq[Expression] = { - key.map { e => e.canonicalized } - } - - override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey - case _ => false + override lazy val canonicalized: HashedRelationBroadcastMode = { + this.copy(key = key.map(_.canonicalized)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 36cde3233dce..59eaf4d1c29b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -36,17 +36,17 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { ) } - test("compatible BroadcastMode") { + test("BroadcastMode.canonicalized") { val mode1 = IdentityBroadcastMode val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) - assert(mode1.compatibleWith(mode1)) - assert(!mode1.compatibleWith(mode2)) - assert(!mode2.compatibleWith(mode1)) - assert(mode2.compatibleWith(mode2)) - assert(!mode2.compatibleWith(mode3)) - assert(mode3.compatibleWith(mode3)) + assert(mode1.canonicalized == mode1.canonicalized) + assert(mode1.canonicalized != mode2.canonicalized) + assert(mode2.canonicalized != mode1.canonicalized) + assert(mode2.canonicalized == mode2.canonicalized) + assert(mode2.canonicalized != mode3.canonicalized) + assert(mode3.canonicalized == mode3.canonicalized) } test("BroadcastExchange same result") { @@ -70,7 +70,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(!exchange1.sameResult(exchange2)) assert(!exchange2.sameResult(exchange3)) - assert(!exchange3.sameResult(exchange4)) + assert(exchange3.sameResult(exchange4)) assert(exchange4 sameResult exchange3) } @@ -98,7 +98,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange1 sameResult exchange2) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) - assert(!exchange4.sameResult(exchange5)) + assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 28f074849c0f..fab0d7fa8482 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -72,7 +72,7 @@ 
case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -80,20 +80,22 @@ case class HiveTableScanExec( BindReferences.bindReference(pred, relation.partitionCols) } - // Create a local copy of hadoopConf,so that scan specific modifications should not impact - // other queries - @transient private val hadoopConf = sparkSession.sessionState.newHadoopConf() - - @transient private val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) - @transient private val tableDesc = new TableDesc( + @transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) + @transient private lazy val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, hiveQlTable.getOutputFormatClass, hiveQlTable.getMetadata) - // append columns ids and names before broadcast - addColumnMetadataToConf(hadoopConf) + // Create a local copy of hadoopConf,so that scan specific modifications should not impact + // other queries + @transient private lazy val hadoopConf = { + val c = sparkSession.sessionState.newHadoopConf() + // append columns ids and names before broadcast + addColumnMetadataToConf(c) + c + } - @transient private val hadoopReader = new HadoopTableReader( + @transient private lazy val hadoopReader = new HadoopTableReader( output, relation.partitionCols, tableDesc, @@ -104,7 +106,7 @@ case class HiveTableScanExec( Cast(Literal(value), dataType).eval(null) } - private def addColumnMetadataToConf(hiveConf: Configuration) { + private def addColumnMetadataToConf(hiveConf: Configuration): Unit = { // Specifies needed column IDs for those non-partitioning columns. val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex) val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer) @@ -198,18 +200,13 @@ case class HiveTableScanExec( } } - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: HiveTableScanExec => - val thisPredicates = partitionPruningPred.map(cleanExpression) - val otherPredicates = other.partitionPruningPred.map(cleanExpression) - - val result = relation.sameResult(other.relation) && - output.length == other.output.length && - output.zip(other.output) - .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) && - thisPredicates.length == otherPredicates.length && - thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) - result - case _ => false + override lazy val canonicalized: HiveTableScanExec = { + val input: AttributeSeq = relation.output + HiveTableScanExec( + requestedAttributes.map(normalizeExprId(_, input)), + relation.canonicalized.asInstanceOf[CatalogRelation], + partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession) } + + override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) } From 4f7d49b955b8c362da29a2540697240f4564d3ee Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Mon, 10 Apr 2017 17:34:15 +0200 Subject: [PATCH 247/512] [SPARK-20243][TESTS] DebugFilesystem.assertNoOpenStreams thread race ## What changes were proposed in this pull request? Synchronize access to openStreams map. ## How was this patch tested? Existing tests. 
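For context, the race this synchronization closes is one thread adding or removing a tracked stream while `assertNoOpenStreams` is still reading the map, so the size check can succeed and the subsequent read of the values can fail. A minimal sketch of the locking pattern the patch adopts, with illustrative names rather than the actual `DebugFilesystem` members:

```scala
import scala.collection.mutable

// Sketch only: every read and write of the shared map takes the map's own
// monitor, so the assertion sees a consistent snapshot even while other
// threads register and unregister resources concurrently.
object OpenResourceTracker {
  private val open = mutable.Map.empty[Long, Throwable]

  def add(id: Long): Unit = open.synchronized {
    open.put(id, new Throwable())
  }

  def remove(id: Long): Unit = open.synchronized {
    open.remove(id)
  }

  def assertNoneOpen(): Unit = open.synchronized {
    // The emptiness check and the iteration happen under one lock.
    if (open.nonEmpty) {
      open.values.foreach(_.printStackTrace())
      throw new IllegalStateException(s"${open.size} resources possibly leaked.")
    }
  }
}
```

A `ConcurrentHashMap` keeps each individual operation thread-safe, but it does not make the size check and the following iteration atomic as a unit, which is why a single monitor around the whole assertion is the simpler fix here.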
Author: Bogdan Raducanu Closes #17592 from bogdanrdc/SPARK-20243. --- .../org/apache/spark/DebugFilesystem.scala | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index 72aea841117c..91355f736290 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io.{FileDescriptor, InputStream} import java.lang import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -31,21 +30,29 @@ import org.apache.spark.internal.Logging object DebugFilesystem extends Logging { // Stores the set of active streams and their creation sites. - private val openStreams = new ConcurrentHashMap[FSDataInputStream, Throwable]() + private val openStreams = mutable.Map.empty[FSDataInputStream, Throwable] - def clearOpenStreams(): Unit = { + def addOpenStream(stream: FSDataInputStream): Unit = openStreams.synchronized { + openStreams.put(stream, new Throwable()) + } + + def clearOpenStreams(): Unit = openStreams.synchronized { openStreams.clear() } - def assertNoOpenStreams(): Unit = { - val numOpen = openStreams.size() + def removeOpenStream(stream: FSDataInputStream): Unit = openStreams.synchronized { + openStreams.remove(stream) + } + + def assertNoOpenStreams(): Unit = openStreams.synchronized { + val numOpen = openStreams.values.size if (numOpen > 0) { - for (exc <- openStreams.values().asScala) { + for (exc <- openStreams.values) { logWarning("Leaked filesystem connection created at:") exc.printStackTrace() } throw new IllegalStateException(s"There are $numOpen possibly leaked file streams.", - openStreams.values().asScala.head) + openStreams.values.head) } } } @@ -60,8 +67,7 @@ class DebugFilesystem extends LocalFileSystem { override def open(f: Path, bufferSize: Int): FSDataInputStream = { val wrapped: FSDataInputStream = super.open(f, bufferSize) - openStreams.put(wrapped, new Throwable()) - + addOpenStream(wrapped) new FSDataInputStream(wrapped.getWrappedStream) { override def setDropBehind(dropBehind: lang.Boolean): Unit = wrapped.setDropBehind(dropBehind) @@ -98,7 +104,7 @@ class DebugFilesystem extends LocalFileSystem { override def close(): Unit = { wrapped.close() - openStreams.remove(wrapped) + removeOpenStream(wrapped) } override def read(): Int = wrapped.read() From 5acaf8c0c685e47ec619fbdfd353163721e1cf50 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 10 Apr 2017 17:45:27 +0200 Subject: [PATCH 248/512] [SPARK-19518][SQL] IGNORE NULLS in first / last in SQL ## What changes were proposed in this pull request? This PR proposes to add `IGNORE NULLS` keyword in `first`/`last` in Spark's parser likewise http://docs.oracle.com/cd/B19306_01/server.102/b14200/functions057.htm. This simply maps the keywords to existing `ignoreNullsExpr`. 
**Before** ```scala scala> sql("select first('a' IGNORE NULLS)").show() ``` ``` org.apache.spark.sql.catalyst.parser.ParseException: extraneous input 'NULLS' expecting {')', ','}(line 1, pos 24) == SQL == select first('a' IGNORE NULLS) ------------------------^^^ at org.apache.spark.sql.catalyst.parser.ParseException.withCommand(ParseDriver.scala:210) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse(ParseDriver.scala:112) at org.apache.spark.sql.execution.SparkSqlParser.parse(SparkSqlParser.scala:46) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parsePlan(ParseDriver.scala:66) at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:622) ... 48 elided ``` **After** ```scala scala> sql("select first('a' IGNORE NULLS)").show() ``` ``` +--------------+ |first(a, true)| +--------------+ | a| +--------------+ ``` ## How was this patch tested? Unit tests in `ExpressionParserSuite`. Author: hyukjinkwon Closes #17566 from HyukjinKwon/SPARK-19518. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 5 ++++- .../spark/sql/catalyst/parser/AstBuilder.scala | 17 +++++++++++++++++ .../catalyst/parser/ExpressionParserSuite.scala | 8 ++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 52b5b347fa9c..1ecb3d1958f4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -552,6 +552,8 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last | constant #constantDefault | ASTERISK #star | qualifiedName '.' 
ASTERISK #star @@ -710,7 +712,7 @@ nonReserved | VIEW | REPLACE | IF | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK + | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT @@ -836,6 +838,7 @@ TRANSACTION: 'TRANSACTION'; COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; +IGNORE: 'IGNORE'; IF: 'IF'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index fab7e4c5b128..c37255153802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -1022,6 +1023,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } + /** + * Create a [[First]] expression. + */ + override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + + /** + * Create a [[Last]] expression. + */ + override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + /** * Create a (windowed) Function expression. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index d1c6b50536cd..e7f3b64a7113 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -549,4 +550,11 @@ class ExpressionParserSuite extends PlanTest { val complexName2 = FunctionIdentifier("ba``r", Some("fo``o")) assertEqual(complexName2.quotedString, UnresolvedAttribute("fo``o.ba``r")) } + + test("SPARK-19526 Support ignore nulls keywords for first and last") { + assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression()) + assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression()) + assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression()) + assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression()) + } } From fd711ea13e558f0e7d3e01f08e01444d394499a6 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 10 Apr 2017 09:15:04 -0700 Subject: [PATCH 249/512] [SPARK-20273][SQL] Disallow Non-deterministic Filter push-down into Join Conditions ## What changes were proposed in this pull request? ``` sql("SELECT t1.b, rand(0) as r FROM cachedData, cachedData t1 GROUP BY t1.b having r > 0.5").show() ``` We will get the following error: ``` Job aborted due to stage failure: Task 1 in stage 4.0 failed 1 times, most recent failure: Lost task 1.0 in stage 4.0 (TID 8, localhost, executor driver): java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source) at org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec$$anonfun$org$apache$spark$sql$execution$joins$BroadcastNestedLoopJoinExec$$boundCondition$1.apply(BroadcastNestedLoopJoinExec.scala:87) at org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec$$anonfun$org$apache$spark$sql$execution$joins$BroadcastNestedLoopJoinExec$$boundCondition$1.apply(BroadcastNestedLoopJoinExec.scala:87) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:463) ``` Filters could be pushed down to the join conditions by the optimizer rule `PushPredicateThroughJoin`. However, Analyzer [blocks users to add non-deterministics conditions](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L386-L395) (For details, see the PR https://github.com/apache/spark/pull/7535). We should not push down non-deterministic conditions; otherwise, we need to explicitly initialize the non-deterministic expressions. This PR is to simply block it. ### How was this patch tested? Added a test case Author: Xiao Li Closes #17585 from gatorsmile/joinRandCondition. 
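As a rough way to observe the effect of this change, the optimized plan for a query like the one above should now keep the non-deterministic predicate in a `Filter` above the join instead of folding it into the join condition. A hypothetical reproduction sketch (local session, illustrative table names):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

// Hypothetical repro sketch: with this fix, the rand(0) predicate is expected
// to stay in a Filter on top of the join rather than being pushed into the
// join condition by PushPredicateThroughJoin.
val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
import spark.implicits._

val t1 = Seq(1, 2, 3).toDF("a")
val t2 = Seq(4, 5, 6).toDF("b")

val q = t1.crossJoin(t2).where(rand(0) > 0.5)
q.explain(true)  // the optimized plan should show Filter(rand(0) > 0.5) above the Join
```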
--- .../spark/sql/catalyst/expressions/predicates.scala | 2 ++ .../sql/catalyst/optimizer/FilterPushdownSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1235204591bb..8acb740f8db8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,6 +90,8 @@ trait PredicateHelper { * Returns true iff `expr` could be evaluated as a condition within join. */ protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { + // Non-deterministic expressions are not allowed as join conditions. + case e if !e.deterministic => false case l: ListQuery => // A ListQuery defines the query which we want to search in an IN subquery expression. // Currently the only way to evaluate an IN subquery is to convert it to a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ccd0b7c5d7f7..950aa2379517 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -241,6 +241,16 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins: do not push down non-deterministic filters into join condition") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = x.join(y).where(Rand(10) > 5.0).analyze + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + test("joins: push to one side after transformCondition") { val x = testRelation.subquery('x) val y = testRelation1.subquery('y) From a26e3ed5e414d0a350cfe65dd511b154868b9f1d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 10 Apr 2017 20:11:56 +0100 Subject: [PATCH 250/512] [SPARK-20156][CORE][SQL][STREAMING][MLLIB] Java String toLowerCase "Turkish locale bug" causes Spark problems ## What changes were proposed in this pull request? Add Locale.ROOT to internal calls to String `toLowerCase`, `toUpperCase`, to avoid inadvertent locale-sensitive variation in behavior (aka the "Turkish locale problem"). The change looks large but it is just adding `Locale.ROOT` (the locale with no country or language specified) to every call to these methods. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #17527 from srowen/SPARK-20156. 
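For readers who have not hit the underlying JDK behaviour, a small self-contained illustration of why `Locale.ROOT` matters; the config values below are just examples drawn from call sites this patch touches:

```scala
import java.util.Locale

// Under a Turkish locale, 'i' upper-cases to the dotted 'İ' (U+0130) and 'I'
// lower-cases to the dotless 'ı' (U+0131), so case-normalised comparisons on
// config strings silently stop matching.
object TurkishLocaleDemo extends App {
  val turkish = new Locale("tr", "TR")

  println("fifo".toUpperCase(turkish))      // FİFO -- no longer matches "FIFO"
  println("fifo".toUpperCase(Locale.ROOT))  // FIFO -- locale-independent

  println("NIO".toLowerCase(turkish))       // nıo  -- dotless ı defeats an equals("nio") check
  println("NIO".toLowerCase(Locale.ROOT))   // nio
}
```

Passing `Locale.ROOT` explicitly pins the conversion to the locale-neutral mapping regardless of the JVM's default locale.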
--- .../apache/spark/network/util/JavaUtils.java | 5 ++- .../spark/network/util/TransportConf.java | 5 ++- .../spark/status/api/v1/TaskSorting.java | 3 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 4 +- .../CoarseGrainedExecutorBackend.scala | 3 +- .../apache/spark/io/CompressionCodec.scala | 4 +- .../spark/metrics/sink/ConsoleSink.scala | 4 +- .../apache/spark/metrics/sink/CsvSink.scala | 2 +- .../spark/metrics/sink/GraphiteSink.scala | 6 +-- .../apache/spark/metrics/sink/Slf4jSink.scala | 4 +- .../scheduler/EventLoggingListener.scala | 3 +- .../spark/scheduler/SchedulableBuilder.scala | 5 ++- .../spark/scheduler/TaskSchedulerImpl.scala | 18 ++++---- .../spark/serializer/KryoSerializer.scala | 4 +- .../ui/exec/ExecutorThreadDumpPage.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 4 +- .../scala/org/apache/spark/ShuffleSuite.scala | 4 +- .../spark/broadcast/BroadcastSuite.scala | 4 +- .../internal/config/ConfigEntrySuite.scala | 3 +- .../BlockManagerReplicationSuite.scala | 6 ++- .../org/apache/spark/ui/StagePageSuite.scala | 5 ++- .../org/apache/spark/ui/UISeleniumSuite.scala | 5 ++- .../scala/org/apache/spark/ui/UISuite.scala | 11 ++--- .../examples/ml/DecisionTreeExample.scala | 4 +- .../apache/spark/examples/ml/GBTExample.scala | 4 +- .../examples/ml/RandomForestExample.scala | 4 +- .../spark/examples/mllib/LDAExample.scala | 4 +- .../sql/kafka010/KafkaSourceProvider.scala | 22 +++++----- .../sql/kafka010/KafkaRelationSuite.scala | 3 +- .../spark/sql/kafka010/KafkaSinkSuite.scala | 24 ++++++----- .../spark/sql/kafka010/KafkaSourceSuite.scala | 6 +-- .../streaming/kafka010/ConsumerStrategy.scala | 9 ++-- .../spark/streaming/kafka/KafkaUtils.scala | 4 +- .../classification/LogisticRegression.scala | 4 +- .../org/apache/spark/ml/clustering/LDA.scala | 5 ++- .../GeneralizedLinearRegressionWrapper.scala | 6 ++- .../apache/spark/ml/recommendation/ALS.scala | 9 ++-- .../GeneralizedLinearRegression.scala | 38 +++++++++-------- .../org/apache/spark/ml/tree/treeParams.scala | 41 ++++++++++++------- .../apache/spark/mllib/clustering/LDA.scala | 4 +- .../spark/mllib/tree/impurity/Impurity.scala | 4 +- .../scala/org/apache/spark/repl/Main.scala | 3 +- .../org/apache/spark/deploy/yarn/Client.scala | 4 +- .../sql/catalyst/analysis/ResolveHints.scala | 4 +- .../catalog/ExternalCatalogUtils.scala | 7 +++- .../sql/catalyst/catalog/SessionCatalog.scala | 3 +- .../catalyst/catalog/functionResources.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 4 +- .../expressions/mathExpressions.scala | 6 ++- .../expressions/regexpExpressions.scala | 3 +- .../expressions/windowExpressions.scala | 4 +- .../sql/catalyst/json/JacksonParser.scala | 5 ++- .../sql/catalyst/parser/AstBuilder.scala | 12 ++++-- .../spark/sql/catalyst/plans/joinTypes.scala | 4 +- .../streaming/InternalOutputModes.scala | 4 +- .../catalyst/util/CaseInsensitiveMap.scala | 9 ++-- .../sql/catalyst/util/CompressionCodecs.scala | 4 +- .../sql/catalyst/util/DateTimeUtils.scala | 4 +- .../spark/sql/catalyst/util/ParseMode.scala | 4 +- .../sql/catalyst/util/StringKeyHashMap.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 6 +-- .../org/apache/spark/sql/types/DataType.scala | 8 +++- .../apache/spark/sql/types/DecimalType.scala | 4 +- .../sql/streaming/JavaOutputModeSuite.java | 6 ++- .../sql/catalyst/analysis/AnalysisTest.scala | 5 ++- .../analysis/UnsupportedOperationsSuite.scala | 7 ++-- .../catalyst/expressions/ScalaUDFSuite.scala | 4 +- 
.../streaming/InternalOutputModesSuite.scala | 4 +- .../spark/sql/DataFrameNaFunctions.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 6 +-- .../spark/sql/RelationalGroupedDataset.scala | 4 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 24 ++++++----- .../spark/sql/execution/SparkSqlParser.scala | 20 +++++---- .../sql/execution/WholeStageCodegenExec.scala | 6 ++- .../spark/sql/execution/command/ddl.scala | 6 ++- .../sql/execution/command/functions.scala | 4 +- .../execution/datasources/DataSource.scala | 16 ++++---- .../datasources/InMemoryFileIndex.scala | 1 - .../datasources/PartitioningUtils.scala | 4 +- .../datasources/csv/CSVOptions.scala | 4 +- .../spark/sql/execution/datasources/ddl.scala | 4 +- .../datasources/jdbc/JDBCOptions.scala | 6 +-- .../datasources/jdbc/JdbcUtils.scala | 3 +- .../datasources/parquet/ParquetOptions.scala | 8 +++- .../sql/execution/datasources/rules.scala | 5 ++- .../state/HDFSBackedStateStoreProvider.scala | 3 +- .../apache/spark/sql/internal/HiveSerDe.scala | 4 +- .../spark/sql/internal/SharedState.scala | 1 - .../sql/streaming/DataStreamReader.scala | 4 +- .../sql/streaming/DataStreamWriter.scala | 4 +- .../apache/spark/sql/JavaDatasetSuite.java | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 3 +- .../sql/execution/QueryExecutionSuite.scala | 13 +++--- .../execution/command/DDLCommandSuite.scala | 7 +++- .../sql/execution/command/DDLSuite.scala | 5 ++- .../datasources/parquet/ParquetIOSuite.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 8 ++-- .../spark/sql/sources/FilteredScanSuite.scala | 9 ++-- .../sql/streaming/FileStreamSinkSuite.scala | 4 +- .../streaming/StreamingAggregationSuite.scala | 4 +- .../test/DataStreamReaderWriterSuite.scala | 7 ++-- .../sql/test/DataFrameReaderWriterSuite.scala | 10 +++-- .../hive/service/auth/HiveAuthFactory.java | 5 ++- .../org/apache/hive/service/auth/SaslQOP.java | 3 +- .../org/apache/hive/service/cli/Type.java | 3 +- .../hive/thriftserver/HiveThriftServer2.scala | 2 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 4 +- .../spark/sql/hive/HiveExternalCatalog.scala | 5 ++- .../spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 11 ++--- .../org/apache/spark/sql/hive/HiveUtils.scala | 3 +- .../sql/hive/client/HiveClientImpl.scala | 9 ++-- .../spark/sql/hive/client/HiveShim.scala | 6 +-- .../sql/hive/execution/HiveOptions.scala | 10 +++-- .../spark/sql/hive/orc/OrcOptions.scala | 6 ++- .../spark/sql/hive/HiveDDLCommandSuite.scala | 3 +- .../sql/hive/HiveSchemaInferenceSuite.scala | 3 -- .../hive/execution/HiveComparisonTest.scala | 8 ++-- .../sql/hive/execution/HiveQuerySuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 7 ++-- .../streaming/dstream/InputDStream.scala | 6 ++- .../apache/spark/streaming/Java8APISuite.java | 5 ++- .../apache/spark/streaming/JavaAPISuite.java | 4 +- .../streaming/StreamingContextSuite.scala | 3 +- 126 files changed, 482 insertions(+), 299 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 51d7fda0cb26..afc59efaef81 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -24,6 +24,7 @@ import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import 
java.nio.charset.StandardCharsets; +import java.util.Locale; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -210,7 +211,7 @@ private static boolean isSymlink(File file) throws IOException { * The unit is also considered the default if the given string does not specify a unit. */ public static long timeStringAs(String str, TimeUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); @@ -258,7 +259,7 @@ public static long timeStringAsSec(String str) { * provided, a direct conversion to the provided unit is attempted. */ public static long byteStringAs(String str, ByteUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index c226d8f3bc8f..a25078e262ef 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,7 @@ package org.apache.spark.network.util; +import java.util.Locale; import java.util.Properties; import com.google.common.primitives.Ints; @@ -75,7 +76,9 @@ public String getModuleName() { } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } + public String ioMode() { + return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(Locale.ROOT); + } /** If true, we will prefer allocating off-heap byte buffers within Netty. 
*/ public boolean preferDirectBufs() { diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java index b38639e85481..dff4f5df6878 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.HashSet; +import java.util.Locale; import java.util.Set; public enum TaskSorting { @@ -35,7 +36,7 @@ public enum TaskSorting { } public static TaskSorting fromString(String str) { - String lower = str.toLowerCase(); + String lower = str.toLowerCase(Locale.ROOT); for (TaskSorting t: values()) { if (t.alternateNames.contains(lower)) { return t; diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0225fd605607..99efc4893fda 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -361,7 +361,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def setLogLevel(logLevel: String) { // let's allow lowercase or mixed case too - val upperCased = logLevel.toUpperCase(Locale.ENGLISH) + val upperCased = logLevel.toUpperCase(Locale.ROOT) require(SparkContext.VALID_LOG_LEVELS.contains(upperCased), s"Supplied level $logLevel did not match one of:" + s" ${SparkContext.VALID_LOG_LEVELS.mkString(",")}") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 539dbb55eeff..f4a59f069a5f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.File import java.net.Socket +import java.util.Locale import scala.collection.mutable import scala.util.Properties @@ -319,7 +320,8 @@ object SparkEnv extends Logging { "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") - val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleMgrClass = + shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ba0096d87456..b2b26ee107c0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable @@ -72,7 +73,7 @@ private[spark] class CoarseGrainedExecutorBackend( def extractLogUrls: Map[String, String] = { val prefix = "SPARK_LOG_URL_" sys.env.filterKeys(_.startsWith(prefix)) - .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) + .map(e => (e._1.substring(prefix.length).toLowerCase(Locale.ROOT), e._2)) } override def receive: 
PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index c216fe477fd1..0cb16f0627b7 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -18,6 +18,7 @@ package org.apache.spark.io import java.io._ +import java.util.Locale import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.LZ4BlockOutputStream @@ -66,7 +67,8 @@ private[spark] object CompressionCodec { } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val codecClass = + shortCompressionCodecNames.getOrElse(codecName.toLowerCase(Locale.ROOT), codecName) val codec = try { val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 81b9056b40fb..fce556fd0382 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} @@ -39,7 +39,7 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR } val pollUnit: TimeUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 9d5f2ae9328a..88bba2fdbd1c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -42,7 +42,7 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis } val pollUnit: TimeUnit = Option(property.getProperty(CSV_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 22454e50b14b..23e31823f493 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink import java.net.InetSocketAddress -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry @@ -59,7 +59,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric } val pollUnit: TimeUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) } @@ -67,7 +67,7 @@ 
private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match { + val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 773e074336cb..7fa4ba762298 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{MetricRegistry, Slf4jReporter} @@ -42,7 +42,7 @@ private[spark] class Slf4jSink( } val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index af9bdefc967e..aecb3a980e7c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI import java.nio.charset.StandardCharsets +import java.util.Locale import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -316,7 +317,7 @@ private[spark] object EventLoggingListener extends Logging { } private def sanitize(str: String): String = { - str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase + str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 20cedaf06042..417103436144 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.io.{FileInputStream, InputStream} -import java.util.{NoSuchElementException, Properties} +import java.util.{Locale, NoSuchElementException, Properties} import scala.util.control.NonFatal import scala.xml.{Node, XML} @@ -142,7 +142,8 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) defaultValue: SchedulingMode, fileName: String): SchedulingMode = { - val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase + val xmlSchedulingMode = + (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase(Locale.ROOT) val warningMessage = s"Unsupported schedulingMode: $xmlSchedulingMode found in " + s"Fair Scheduler configuration file: $fileName, using " + s"the default schedulingMode: $defaultValue for pool: $poolName" diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 07aea773fa63..c849a16023a7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{Timer, TimerTask} +import java.util.{Locale, Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong @@ -56,8 +56,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val maxTaskFailures: Int, private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) - extends TaskScheduler with Logging -{ + extends TaskScheduler with Logging { import TaskSchedulerImpl._ @@ -135,12 +134,13 @@ private[spark] class TaskSchedulerImpl private[scheduler]( private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO private val schedulingModeConf = conf.get(SCHEDULER_MODE_PROPERTY, SchedulingMode.FIFO.toString) - val schedulingMode: SchedulingMode = try { - SchedulingMode.withName(schedulingModeConf.toUpperCase) - } catch { - case e: java.util.NoSuchElementException => - throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") - } + val schedulingMode: SchedulingMode = + try { + SchedulingMode.withName(schedulingModeConf.toUpperCase(Locale.ROOT)) + } catch { + case e: java.util.NoSuchElementException => + throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") + } val rootPool: Pool = new Pool("", schedulingMode, 0, 0) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 6fc66e2374bd..e15166d11c24 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import java.util.Locale import javax.annotation.Nullable import scala.collection.JavaConverters._ @@ -244,7 +245,8 @@ class KryoDeserializationStream( kryo.readClassAndObject(input).asInstanceOf[T] } catch { // DeserializationStream uses the EOF exception to indicate stopping condition. 
- case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") => + case e: KryoException + if e.getMessage.toLowerCase(Locale.ROOT).contains("buffer underflow") => throw new EOFException } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index dbcc6402bc30..6ce3f511e89c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -42,7 +43,8 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 if (v1 == v2) { - threadTrace1.threadName.toLowerCase < threadTrace2.threadName.toLowerCase + threadTrace1.threadName.toLowerCase(Locale.ROOT) < + threadTrace2.threadName.toLowerCase(Locale.ROOT) } else { v1 > v2 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 0ff9e5e9411c..3131c4a1eb7d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import java.util.Date +import java.util.{Date, Locale} import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, ListBuffer} @@ -77,7 +77,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'content': '
    // Put the block into one of the stores - val blockId = new TestBlockId( - "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) + val blockId = TestBlockId( + "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase(Locale.ROOT)) val testValue = Array.fill[Byte](blockSize)(1) stores(0).putSingle(blockId, testValue, storageLevel) diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 38030e066080..499d47b13d70 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -37,14 +38,14 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { test("peak execution memory should displayed") { val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase + val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase + val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index f4c561c73779..bdd148875e38 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} +import java.util.Locale import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -453,8 +454,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B eventually(timeout(10 seconds), interval(50 milliseconds)) { goToUi(sc, "/jobs") findAll(cssSelector("tbody tr a")).foreach { link => - link.text.toLowerCase should include ("count") - link.text.toLowerCase should not include "unknown" + link.text.toLowerCase(Locale.ROOT) should include ("count") + link.text.toLowerCase(Locale.ROOT) should not include "unknown" } } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index f1be0f6de3ce..0c3d4caeeabf 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui import java.net.{BindException, ServerSocket} import java.net.{URI, URL} +import java.util.Locale import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -72,10 +73,10 @@ class UISuite extends SparkFunSuite { eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) - assert(html.toLowerCase.contains("stages")) - assert(html.toLowerCase.contains("storage")) - assert(html.toLowerCase.contains("environment")) - assert(html.toLowerCase.contains("executors")) + 
assert(html.toLowerCase(Locale.ROOT).contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("storage")) + assert(html.toLowerCase(Locale.ROOT).contains("environment")) + assert(html.toLowerCase(Locale.ROOT).contains("executors")) } } } @@ -85,7 +86,7 @@ class UISuite extends SparkFunSuite { // test if visible from http://localhost:4040 eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL("http://localhost:4040").mkString - assert(html.toLowerCase.contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) } } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 1745281c266c..f736ceed4436 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -203,7 +205,7 @@ object DecisionTreeExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"DecisionTreeExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index db55298d8ea1..ed598d0d7dfa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -140,7 +142,7 @@ object GBTExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"GBTExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index a9e07c0705c9..8fd46c37e298 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -146,7 +148,7 @@ object RandomForestExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"RandomForestExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index b923e627f209..cd77ecf990b3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import java.util.Locale + import org.apache.log4j.{Level, Logger} import 
scopt.OptionParser @@ -131,7 +133,7 @@ object LDAExample { // Run LDA. val lda = new LDA() - val optimizer = params.algorithm.toLowerCase match { + val optimizer = params.algorithm.toLowerCase(Locale.ROOT) match { case "em" => new EMLDAOptimizer // add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets. case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 58b52692b57c..ab1ce347cbe3 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.UUID +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -74,11 +74,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // id. Hence, we should generate a unique id for each query. val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap @@ -115,11 +115,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister // partial data since Kafka will assign partitions to multiple consumers having the same group // id. Hence, we should generate a unique id for each query. 
val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap @@ -192,7 +192,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { throw new IllegalArgumentException( s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " @@ -207,7 +207,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) @@ -272,7 +272,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister private def validateGeneralOptions(parameters: Map[String, String]): Unit = { // Validate source options - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedStrategies = caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq @@ -451,8 +451,10 @@ private[kafka010] object KafkaSourceProvider { offsetOptionKey: String, defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = { params.get(offsetOptionKey).map(_.trim) match { - case Some(offset) if offset.toLowerCase == "latest" => LatestOffsetRangeLimit - case Some(offset) if offset.toLowerCase == "earliest" => EarliestOffsetRangeLimit + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "latest" => + LatestOffsetRangeLimit + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "earliest" => + EarliestOffsetRangeLimit case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) case None => defaultOffsets } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 68bc3e3e2e9a..91893df4ec32 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.common.TopicPartition @@ -195,7 +196,7 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon reader.load() } expectedMsgs.foreach { m => - assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } 
} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 490535623cb3..4bd052d249ec 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.clients.producer.ProducerConfig @@ -75,7 +76,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .option("kafka.bootstrap.servers", testUtils.brokerAddress) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "null topic present in the data")) } @@ -92,7 +93,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .mode(SaveMode.Ignore) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( s"save mode ignore not allowed for kafka")) // Test bad save mode Overwrite @@ -103,7 +104,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { .mode(SaveMode.Overwrite) .save() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( s"save mode overwrite not allowed for kafka")) } @@ -233,7 +234,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage - .toLowerCase + .toLowerCase(Locale.ROOT) .contains("topic option required when no 'topic' attribute is present")) try { @@ -248,7 +249,8 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("required attribute 'value' not found")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) } test("streaming - write data with valid schema but wrong types") { @@ -270,7 +272,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("topic type must be a string")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) try { /* value field wrong type */ @@ -284,7 +286,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "value attribute type must be a string or binarytype")) try { @@ -299,7 +301,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "key attribute type must be a string or binarytype")) } @@ -318,7 +320,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase.contains("job aborted")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { @@ -330,7 +332,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'key.serializer' is not supported")) ex 
= intercept[IllegalArgumentException] { @@ -338,7 +340,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() } - assert(ex.getMessage.toLowerCase.contains( + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'value.serializer' is not supported")) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 0046ba7e43d1..2034b9be07f2 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger @@ -491,7 +491,7 @@ class KafkaSourceSuite extends KafkaSourceTest { reader.load() } expectedMsgs.foreach { m => - assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } @@ -524,7 +524,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option(s"$key", value) reader.load() } - assert(ex.getMessage.toLowerCase.contains("not supported")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("not supported")) } testUnsupportedConfig("kafka.group.id") diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index 778c06ea16a2..d2100fc5a4ab 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.kafka010 -import java.{ lang => jl, util => ju } +import java.{lang => jl, util => ju} +import java.util.Locale import scala.collection.JavaConverters._ @@ -93,7 +94,8 @@ private case class Subscribe[K, V]( // but cant seek to a position before poll, because poll is what gets subscription partitions // So, poll, suppress the first exception, then seek val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) - val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + val shouldSuppress = + aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE" try { consumer.poll(0) } catch { @@ -145,7 +147,8 @@ private case class SubscribePattern[K, V]( if (!toSeek.isEmpty) { // work around KAFKA-3370 when reset is none, see explanation in Subscribe above val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) - val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + val shouldSuppress = + aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE" try { consumer.poll(0) } catch { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index d5aef8184fc8..78230725f322 100644 --- 
a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import java.io.OutputStream import java.lang.{Integer => JInt, Long => JLong, Number => JNumber} import java.nio.charset.StandardCharsets -import java.util.{List => JList, Map => JMap, Set => JSet} +import java.util.{List => JList, Locale, Map => JMap, Set => JSet} import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -206,7 +206,7 @@ object KafkaUtils { kafkaParams: Map[String, String], topics: Set[String] ): Map[TopicAndPartition, Long] = { - val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase(Locale.ROOT)) val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 7b56bce41c32..965ce3d6f275 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import java.util.Locale + import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} @@ -654,7 +656,7 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { override def load(path: String): LogisticRegression = super.load(path) private[classification] val supportedFamilyNames = - Array("auto", "binomial", "multinomial").map(_.toLowerCase) + Array("auto", "binomial", "multinomial").map(_.toLowerCase(Locale.ROOT)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 55720e2d613d..2f50dc7c85f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import java.util.Locale + import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats import org.json4s.JsonAST.JObject @@ -173,7 +175,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM @Since("1.6.0") final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + " algorithm used to estimate the LDA model. 
Supported: " + supportedOptimizers.mkString(", "), - (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) + (o: String) => + ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index c49416b24018..4bd4aa7113f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.r +import java.util.Locale + import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.JsonDSL._ @@ -91,7 +93,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) // set variancePower and linkPower if family is tweedie; otherwise, set link function - if (family.toLowerCase == "tweedie") { + if (family.toLowerCase(Locale.ROOT) == "tweedie") { glr.setVariancePower(variancePower).setLinkPower(linkPower) } else { glr.setLink(link) @@ -151,7 +153,7 @@ private[r] object GeneralizedLinearRegressionWrapper val rDeviance: Double = summary.deviance val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom - val rAic: Double = if (family.toLowerCase == "tweedie" && + val rAic: Double = if (family.toLowerCase(Locale.ROOT) == "tweedie" && !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) { 0.0 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 60dd7367053e..a20ef7244666 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.recommendation import java.{util => ju} import java.io.IOException +import java.util.Locale import scala.collection.mutable import scala.reflect.ClassTag @@ -40,8 +41,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -118,10 +118,11 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo "useful in cross-validation or production scenarios, for handling user/item ids the model " + "has not seen in the training data. 
Supported values: " + s"${ALSModel.supportedColdStartStrategies.mkString(",")}.", - (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase)) + (s: String) => + ALSModel.supportedColdStartStrategies.contains(s.toLowerCase(Locale.ROOT))) /** @group expertGetParam */ - def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase + def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase(Locale.ROOT) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 3be8b533ee3f..33137b0c0fde 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.regression +import java.util.Locale + import breeze.stats.{distributions => dist} import org.apache.hadoop.fs.Path @@ -57,7 +59,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the error distribution to be used in the " + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", - (value: String) => supportedFamilyNames.contains(value.toLowerCase)) + (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.0.0") @@ -99,7 +101,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val link: Param[String] = new Param(this, "link", "The name of link function " + "which provides the relationship between the linear predictor and the mean of the " + s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", - (value: String) => supportedLinkNames.contains(value.toLowerCase)) + (value: String) => supportedLinkNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.0.0") @@ -148,7 +150,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - if ($(family).toLowerCase == "tweedie") { + if ($(family).toLowerCase(Locale.ROOT) == "tweedie") { if (isSet(link)) { logWarning("When family is tweedie, use param linkPower to specify link function. 
" + "Setting param link will take no effect.") @@ -460,13 +462,15 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine */ def apply(params: GeneralizedLinearRegressionBase): FamilyAndLink = { val familyObj = Family.fromParams(params) - val linkObj = if ((params.getFamily.toLowerCase != "tweedie" && - params.isSet(params.link)) || (params.getFamily.toLowerCase == "tweedie" && - params.isSet(params.linkPower))) { - Link.fromParams(params) - } else { - familyObj.defaultLink - } + val linkObj = + if ((params.getFamily.toLowerCase(Locale.ROOT) != "tweedie" && + params.isSet(params.link)) || + (params.getFamily.toLowerCase(Locale.ROOT) == "tweedie" && + params.isSet(params.linkPower))) { + Link.fromParams(params) + } else { + familyObj.defaultLink + } new FamilyAndLink(familyObj, linkObj) } } @@ -519,7 +523,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param params the parameter map containing family name and variance power */ def fromParams(params: GeneralizedLinearRegressionBase): Family = { - params.getFamily.toLowerCase match { + params.getFamily.toLowerCase(Locale.ROOT) match { case Gaussian.name => Gaussian case Binomial.name => Binomial case Poisson.name => Poisson @@ -795,7 +799,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param params the parameter map containing family, link and linkPower */ def fromParams(params: GeneralizedLinearRegressionBase): Link = { - if (params.getFamily.toLowerCase == "tweedie") { + if (params.getFamily.toLowerCase(Locale.ROOT) == "tweedie") { params.getLinkPower match { case 0.0 => Log case 1.0 => Identity @@ -804,7 +808,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine case others => new Power(others) } } else { - params.getLink.toLowerCase match { + params.getLink.toLowerCase(Locale.ROOT) match { case Identity.name => Identity case Logit.name => Logit case Log.name => Log @@ -1253,8 +1257,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val dispersion: Double = if ( - model.getFamily.toLowerCase == Binomial.name || - model.getFamily.toLowerCase == Poisson.name) { + model.getFamily.toLowerCase(Locale.ROOT) == Binomial.name || + model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { 1.0 } else { val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0) @@ -1357,8 +1361,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] ( @Since("2.0.0") lazy val pValues: Array[Double] = { if (isNormalSolver) { - if (model.getFamily.toLowerCase == Binomial.name || - model.getFamily.toLowerCase == Poisson.name) { + if (model.getFamily.toLowerCase(Locale.ROOT) == Binomial.name || + model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } } else { tValues.map { x => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 5eb707dfe7bc..cd1950bd76c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import java.util.Locale + import scala.util.Try import org.apache.spark.ml.PredictorParams @@ -218,7 +220,8 @@ private[ml] trait TreeClassifierParams extends Params { final val impurity: Param[String] = new 
Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) + (value: String) => + TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "gini") @@ -230,7 +233,7 @@ private[ml] trait TreeClassifierParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -247,7 +250,8 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = + Array("entropy", "gini").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeClassifierParams @@ -267,7 +271,8 @@ private[ml] trait TreeRegressorParams extends Params { final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) + (value: String) => + TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -279,7 +284,7 @@ private[ml] trait TreeRegressorParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -295,7 +300,8 @@ private[ml] trait TreeRegressorParams extends Params { private[ml] object TreeRegressorParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) + final val supportedImpurities: Array[String] = + Array("variance").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams @@ -417,7 +423,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + s", (0.0-1.0], [1-n].", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) + RandomForestParams.supportedFeatureSubsetStrategies.contains( + value.toLowerCase(Locale.ROOT)) || Try(value.toInt).filter(_ > 0).isSuccess || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) @@ -431,13 +438,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } private[spark] object RandomForestParams { // These options should be lowercase. 
final val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait RandomForestClassifierParams @@ -509,7 +516,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { private[ml] object GBTClassifierParams { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = + Array("logistic").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { @@ -523,12 +531,13 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + (value: String) => + GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT))) setDefault(lossType -> "logistic") /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase(Locale.ROOT) /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldClassificationLoss = { @@ -544,7 +553,8 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam private[ml] object GBTRegressorParams { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = + Array("squared", "absolute").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { @@ -558,12 +568,13 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + (value: String) => + GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT))) setDefault(lossType -> "squared") /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase(Locale.ROOT) /** (private[ml]) Convert new loss to old loss. 
*/ override private[ml] def getOldLossType: OldLoss = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 6c5f529fb8bf..4aa647236b31 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.Locale + import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.{DeveloperApi, Since} @@ -306,7 +308,7 @@ class LDA private ( @Since("1.4.0") def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = - optimizerName.toLowerCase match { + optimizerName.toLowerCase(Locale.ROOT) match { case "em" => new EMLDAOptimizer case "online" => new OnlineLDAOptimizer case other => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 98a3021461eb..4c7746869dde 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.impurity +import java.util.Locale + import org.apache.spark.annotation.{DeveloperApi, Since} /** @@ -184,7 +186,7 @@ private[spark] object ImpurityCalculator { * the given stats. */ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { - impurity.toLowerCase match { + impurity.toLowerCase(Locale.ROOT) match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 7f2ec01cc967..39fc621de780 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -18,6 +18,7 @@ package org.apache.spark.repl import java.io.File +import java.util.Locale import scala.tools.nsc.GenericRunnerSettings @@ -88,7 +89,7 @@ object Main extends Logging { } val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { if (SparkSession.hiveClassesArePresent) { // In the case that the property is not set at all, builder's config // does not have this value set to 'hive' yet. 
The original default diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 3218d221143e..424bbca12319 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,7 +21,7 @@ import java.io.{File, FileOutputStream, IOException, OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets -import java.util.{Properties, UUID} +import java.util.{Locale, Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ @@ -532,7 +532,7 @@ private[spark] class Client( try { jarsStream.setLevel(0) jarsDir.listFiles().foreach { f => - if (f.isFile && f.getName.toLowerCase().endsWith(".jar") && f.canRead) { + if (f.isFile && f.getName.toLowerCase(Locale.ROOT).endsWith(".jar") && f.canRead) { jarsStream.putNextEntry(new ZipEntry(f.getName)) Files.copy(f, jarsStream) jarsStream.closeEntry() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f8004ca300ac..c4827b81e8b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -83,7 +85,7 @@ object ResolveHints { } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase) => + case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => applyBroadcastHint(h.child, h.parameters.toSet) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 254eedfe7751..3ca9e6a8da5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -167,8 +168,10 @@ object CatalogUtils { */ def maskCredentials(options: Map[String, String]): Map[String, String] = { options.map { - case (key, _) if key.toLowerCase == "password" => (key, "###") - case (key, value) if key.toLowerCase == "url" && value.toLowerCase.contains("password") => + case (key, _) if key.toLowerCase(Locale.ROOT) == "password" => (key, "###") + case (key, value) + if key.toLowerCase(Locale.ROOT) == "url" && + value.toLowerCase(Locale.ROOT).contains("password") => (key, "###") case o => o } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6f8c6ee2f0f4..faedf5f91c3e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import java.util.Locale import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -1098,7 +1099,7 @@ class SessionCatalog( name.database.isEmpty && functionRegistry.functionExists(name.funcName) && !FunctionRegistry.builtin.functionExists(name.funcName) && - !hiveFunctions.contains(name.funcName.toLowerCase) + !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } protected def failFunctionLookup(name: String): Nothing = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala index 8e46b962ff43..67bf2d06c95d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.util.Locale + import org.apache.spark.sql.AnalysisException /** A trait that represents the type of a resourced needed by a function. */ @@ -33,7 +35,7 @@ object ArchiveResource extends FunctionResourceType("archive") object FunctionResourceType { def fromString(resourceType: String): FunctionResourceType = { - resourceType.toLowerCase match { + resourceType.toLowerCase(Locale.ROOT) match { case "jar" => JarResource case "file" => FileResource case "archive" => ArchiveResource diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1db26d9c415a..b847ef7bfaa9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -184,7 +186,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns a user-facing string representation of this expression's name. * This should usually match the name of the function in SQL. 
*/ - def prettyName: String = nodeName.toLowerCase + def prettyName: String = nodeName.toLowerCase(Locale.ROOT) protected def flatArguments: Iterator[Any] = productIterator.flatMap { case t: Traversable[_] => t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index dea5f85cb08c..c4d47ab2084f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} +import java.util.Locale import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -68,7 +69,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) } // name of function in java.lang.Math - def funcName: String = name.toLowerCase + def funcName: String = name.toLowerCase(Locale.ROOT) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") @@ -124,7 +125,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + defineCodeGen(ctx, ev, (c1, c2) => + s"java.lang.Math.${name.toLowerCase(Locale.ROOT)}($c1, $c2)") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b23da537be72..49b779711308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils @@ -60,7 +61,7 @@ abstract class StringRegexExpression extends BinaryExpression } } - override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" + override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index b2a3888ff7b0..37190429fc42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -631,7 +633,7 @@ abstract class RankLike extends AggregateWindowFunction { override val updateExpressions = increaseRank +: increaseRowNumber +: children override val evaluateExpression: Expression = rank - override def sql: String 
= s"${prettyName.toUpperCase}()" + override def sql: String = s"${prettyName.toUpperCase(Locale.ROOT)}()" def withOrder(order: Seq[Expression]): RankLike } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index fdb7d88d5bd7..ff6c93ae9815 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.ByteArrayOutputStream +import java.util.Locale import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -126,7 +127,7 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. val value = parser.getText - val lowerCaseValue = value.toLowerCase + val lowerCaseValue = value.toLowerCase(Locale.ROOT) if (lowerCaseValue.equals("nan") || lowerCaseValue.equals("infinity") || lowerCaseValue.equals("-infinity") || @@ -146,7 +147,7 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. val value = parser.getText - val lowerCaseValue = value.toLowerCase + val lowerCaseValue = value.toLowerCase(Locale.ROOT) if (lowerCaseValue.equals("nan") || lowerCaseValue.equals("infinity") || lowerCaseValue.equals("-infinity") || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c37255153802..e1db1ef5b869 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.parser import java.sql.{Date, Timestamp} +import java.util.Locale import javax.xml.bind.DatatypeConverter import scala.collection.JavaConverters._ @@ -1047,7 +1048,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) val arguments = ctx.namedExpression().asScala.map(expression) match { - case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + case Seq(UnresolvedStar(None)) + if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). Seq(Literal(1)) case expressions => @@ -1271,7 +1273,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { val value = string(ctx.STRING) - val valueType = ctx.identifier.getText.toUpperCase + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) try { valueType match { case "DATE" => @@ -1427,7 +1429,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { import ctx._ val s = value.getText try { - val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + val unitText = unit.getText.toLowerCase(Locale.ROOT) + val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match { case (u, None) if u.endsWith("s") => // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... 
CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) @@ -1465,7 +1468,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Resolve/create a primitive type. */ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { - (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT) + (dataType, ctx.INTEGER_VALUE().asScala.toList) match { case ("boolean", Nil) => BooleanType case ("tinyint" | "byte", Nil) => ByteType case ("smallint" | "short", Nil) => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 818f4e5ed2ae..90d11d6d9151 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.plans +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Attribute object JoinType { - def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_", "") match { case "inner" => Inner case "outer" | "full" | "fullouter" => FullOuter case "leftouter" | "left" => LeftOuter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala index bdf2baf7361d..3cd6970ebefb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.streaming +import java.util.Locale + import org.apache.spark.sql.streaming.OutputMode /** @@ -47,7 +49,7 @@ private[sql] object InternalOutputModes { def apply(outputMode: String): OutputMode = { - outputMode.toLowerCase match { + outputMode.toLowerCase(Locale.ROOT) match { case "append" => OutputMode.Append case "complete" => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index 66dd093bbb69..bb2c5926ae9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + /** * Builds a map in which keys are case insensitive. Input map can be accessed for cases where * case-sensitive information is required. 
The primary constructor is marked private to avoid @@ -26,11 +28,12 @@ package org.apache.spark.sql.catalyst.util class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T] with Serializable { - val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase)) + val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase(Locale.ROOT))) - override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase) + override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase(Locale.ROOT)) - override def contains(k: String): Boolean = keyLowerCasedMap.contains(k.toLowerCase) + override def contains(k: String): Boolean = + keyLowerCasedMap.contains(k.toLowerCase(Locale.ROOT)) override def +[B1 >: T](kv: (String, B1)): Map[String, B1] = { new CaseInsensitiveMap(originalMap + kv) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala index 435fba9d8851..1377a03d93b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress._ @@ -38,7 +40,7 @@ object CompressionCodecs { * If it is already a class name, just return it. */ def getCodecClassName(name: String): String = { - val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) + val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase(Locale.ROOT), name) try { // Validate the codec name if (codecName != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f614965520f4..eb6aad5b2d2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -894,7 +894,7 @@ object DateTimeUtils { * (Because 1970-01-01 is Thursday). */ def getDayOfWeekFromString(string: UTF8String): Int = { - val dowString = string.toString.toUpperCase + val dowString = string.toString.toUpperCase(Locale.ROOT) dowString match { case "SU" | "SUN" | "SUNDAY" => 3 case "MO" | "MON" | "MONDAY" => 4 @@ -951,7 +951,7 @@ object DateTimeUtils { if (format == null) { TRUNC_INVALID } else { - format.toString.toUpperCase match { + format.toString.toUpperCase(Locale.ROOT) match { case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH case _ => TRUNC_INVALID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala index 4565dbde88c8..2beb875d1751 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + import org.apache.spark.internal.Logging sealed trait ParseMode { @@ -45,7 +47,7 @@ object ParseMode extends Logging { /** * Returns the parse mode from the given string. 
*/ - def fromString(mode: String): ParseMode = mode.toUpperCase match { + def fromString(mode: String): ParseMode = mode.toUpperCase(Locale.ROOT) match { case PermissiveMode.name => PermissiveMode case DropMalformedMode.name => DropMalformedMode case FailFastMode.name => FailFastMode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala index a7ac6136835a..812d5ded4bf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + /** * Build a map with String type of key, and it also supports either key case * sensitive or insensitive. @@ -25,7 +27,7 @@ object StringKeyHashMap { def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = if (caseSensitive) { new StringKeyHashMap[T](identity) } else { - new StringKeyHashMap[T](_.toLowerCase) + new StringKeyHashMap[T](_.toLowerCase(Locale.ROOT)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 640c0f189c23..6b0f49503349 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal -import java.util.{NoSuchElementException, Properties, TimeZone} +import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -243,7 +243,7 @@ object SQLConf { .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + "uncompressed, snappy, gzip, lzo.") .stringConf - .transform(_.toLowerCase()) + .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") @@ -324,7 +324,7 @@ object SQLConf { "properties) and NEVER_INFER (fallback to using the case-insensitive metastore schema " + "instead of inferring).") .stringConf - .transform(_.toUpperCase()) + .transform(_.toUpperCase(Locale.ROOT)) .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 26871259c6b6..520aff5e2b67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -49,7 +51,9 @@ abstract class DataType extends AbstractDataType { /** Name of the type used in JSON serialization. 
*/ def typeName: String = { - this.getClass.getSimpleName.stripSuffix("$").stripSuffix("Type").stripSuffix("UDT").toLowerCase + this.getClass.getSimpleName + .stripSuffix("$").stripSuffix("Type").stripSuffix("UDT") + .toLowerCase(Locale.ROOT) } private[sql] def jsonValue: JValue = typeName @@ -69,7 +73,7 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type with truncation */ private[sql] def simpleString(maxNumberFields: Int): String = simpleString - def sql: String = simpleString.toUpperCase + def sql: String = simpleString.toUpperCase(Locale.ROOT) /** * Check if `this` and `other` are the same data type when ignoring nullability diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 4dc06fc9cf09..5c4bc5e33c53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability @@ -65,7 +67,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def toString: String = s"DecimalType($precision,$scale)" - override def sql: String = typeName.toUpperCase + override def sql: String = typeName.toUpperCase(Locale.ROOT) /** * Returns whether this DecimalType is wider than `other`. If yes, it means `other` diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java index e0a54fe30ac7..d8845e0c838f 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming; +import java.util.Locale; + import org.junit.Test; public class JavaOutputModeSuite { @@ -24,8 +26,8 @@ public class JavaOutputModeSuite { @Test public void testOutputModes() { OutputMode o1 = OutputMode.Append(); - assert(o1.toString().toLowerCase().contains("append")); + assert(o1.toString().toLowerCase(Locale.ROOT).contains("append")); OutputMode o2 = OutputMode.Complete(); - assert (o2.toString().toLowerCase().contains("complete")); + assert (o2.toString().toLowerCase(Locale.ROOT).contains("complete")); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 1be25ec06c74..82015b1e0671 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest @@ -79,7 +81,8 @@ trait AnalysisTest extends PlanTest { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { + if (!expectedErrors.map(_.toLowerCase(Locale.ROOT)).forall( + e.getMessage.toLowerCase(Locale.ROOT).contains)) { 
fail( s"""Exception message should contain the following substrings: | diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 8f0a0c0d99d1..c39e372c272b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -17,19 +17,20 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} -import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -696,7 +697,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBody } expectedMsgs.foreach { m => - if (!e.getMessage.toLowerCase.contains(m.toLowerCase)) { + if (!e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) { fail(s"Exception message should contain: '$m', " + s"actual exception message:\n\t'${e.getMessage}'") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 7e45028653e3..13bd363c8b69 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.types.{IntegerType, StringType} @@ -32,7 +34,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("better error message for NPE") { val udf = ScalaUDF( - (s: String) => s.toLowerCase, + (s: String) => s.toLowerCase(Locale.ROOT), StringType, Literal.create(null, StringType) :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala index 201dac35ed2d..3159b541dca7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.streaming +import java.util.Locale + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.OutputMode @@ 
-40,7 +42,7 @@ class InternalOutputModesSuite extends SparkFunSuite { val acceptedModes = Seq("append", "update", "complete") val e = intercept[IllegalArgumentException](InternalOutputModes(outputMode)) (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } testMode("Xyz") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index d8f953fba5a8..93d565d9fe90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.{lang => jl} +import java.util.Locale import scala.collection.JavaConverters._ @@ -89,7 +90,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def drop(how: String, cols: Seq[String]): DataFrame = { - how.toLowerCase match { + how.toLowerCase(Locale.ROOT) match { case "any" => drop(cols.size, cols) case "all" => drop(1, cols) case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2b8537c3d4a6..49691c15d0f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.Properties +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -164,7 +164,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { */ @scala.annotation.varargs def load(paths: String*): DataFrame = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 338a6e1314d9..1732a8e08b73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.Properties +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -66,7 +66,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter[T] = { - this.mode = saveMode.toLowerCase match { + this.mode = saveMode.toLowerCase(Locale.ROOT) match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append case "ignore" => SaveMode.Ignore @@ -223,7 +223,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def save(): Unit = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "write files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 0fe8d87ebd6b..64755434784a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Locale + import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -108,7 +110,7 @@ class RelationalGroupedDataset protected[sql]( private[this] def strToExpr(expr: String): (Expression => Expression) = { val exprToFunc: (Expression => Expression) = { - (inputExpr: Expression) => expr.toLowerCase match { + (inputExpr: Expression) => expr.toLowerCase(Locale.ROOT) match { // We special handle a few cases that have alias that are not in function registry. case "avg" | "average" | "mean" => UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index c77328690dae..a26d00411fba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import java.util.{Map => JMap} +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ import scala.util.matching.Regex @@ -47,17 +47,19 @@ private[sql] object SQLUtils extends Logging { jsc: JavaSparkContext, sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { - val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport - && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { - SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() - } else { - if (enableHiveSupport) { - logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + - s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to 'hive', " + - "falling back to without Hive support.") + val spark = + if (SparkSession.hiveClassesArePresent && enableHiveSupport && + jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == + "hive") { + SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() + } else { + if (enableHiveSupport) { + logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + + s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to " + + "'hive', falling back to without Hive support.") + } + SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } - SparkSession.builder().sparkContext(jsc.sc).getOrCreate() - } setSparkContextSessionConf(spark, sparkConfigMap) spark } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 80afb59b3e88..20dacf88504f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.Locale + import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -103,7 +105,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { 
logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}") } if (ctx.identifier != null) { - if (ctx.identifier.getText.toLowerCase != "noscan") { + if (ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) @@ -563,7 +565,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } else if (value.STRING != null) { string(value.STRING) } else if (value.booleanValue != null) { - value.getText.toLowerCase + value.getText.toLowerCase(Locale.ROOT) } else { value.getText } @@ -647,7 +649,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { import ctx._ - val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase) match { + val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase(Locale.ROOT)) match { case None | Some("all") => (true, true) case Some("system") => (false, true) case Some("user") => (true, false) @@ -677,7 +679,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { val resources = ctx.resource.asScala.map { resource => - val resourceType = resource.identifier.getText.toLowerCase + val resourceType = resource.identifier.getText.toLowerCase(Locale.ROOT) resourceType match { case "jar" | "file" | "archive" => FunctionResource(FunctionResourceType.fromString(resourceType), string(resource.STRING)) @@ -959,7 +961,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .flatMap(_.orderedIdentifier.asScala) .map { orderedIdCtx => Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => - if (dir.toLowerCase != "asc") { + if (dir.toLowerCase(Locale.ROOT) != "asc") { operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) } } @@ -1012,13 +1014,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val mayebePaths = remainder(ctx.identifier).trim ctx.op.getType match { case SqlBaseParser.ADD => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "file" => AddFileCommand(mayebePaths) case "jar" => AddJarCommand(mayebePaths) case other => operationNotAllowed(s"ADD with resource type '$other'", ctx) } case SqlBaseParser.LIST => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "files" | "file" => if (mayebePaths.length > 0) { ListFilesCommand(mayebePaths.split("\\s+")) @@ -1305,7 +1307,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { (rowFormatCtx, createFileFormatCtx.fileFormat) match { case (_, ffTable: TableFileFormatContext) => // OK case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case ("sequencefile" | "textfile" | "rcfile") => // OK case fmt => operationNotAllowed( @@ -1313,7 +1315,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { parentCtx) } case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case "textfile" => // OK case fmt => operationNotAllowed( s"ROW FORMAT DELIMITED is only compatible 
with 'textfile', not '$fmt'", parentCtx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c31fd92447c0..c1e1a631c677 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import org.apache.spark.{broadcast, TaskContext} +import java.util.Locale + +import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -43,7 +45,7 @@ trait CodegenSupport extends SparkPlan { case _: SortMergeJoinExec => "smj" case _: RDDScanExec => "rdd" case _: DataSourceScanExec => "scan" - case _ => nodeName.toLowerCase + case _ => nodeName.toLowerCase(Locale.ROOT) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 9d3c55060dfb..55540563ef91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.Locale + import scala.collection.{GenMap, GenSeq} import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool @@ -764,11 +766,11 @@ object DDLUtils { val HIVE_PROVIDER = "hive" def isHiveTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get.toLowerCase == HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER } def isDatasourceTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get.toLowerCase != HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) != HIVE_PROVIDER } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index ea5398761c46..5687f9332430 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException} @@ -100,7 +102,7 @@ case class DescribeFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { // Hard code "<>", "!=", "between", and "case" for now as there is no corresponding functions. 
- functionName.funcName.toLowerCase match { + functionName.funcName.toLowerCase(Locale.ROOT) match { case "<>" => Row(s"Function: $functionName") :: Row("Usage: expr1 <> expr2 - " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index c9384e44255b..f3b209deaae5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.datasources -import java.util.{ServiceConfigurationError, ServiceLoader} +import java.util.{Locale, ServiceConfigurationError, ServiceLoader} import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -539,15 +538,16 @@ object DataSource { // Found the data source using fully qualified path dataSource case Failure(error) => - if (provider1.toLowerCase == "orc" || + if (provider1.toLowerCase(Locale.ROOT) == "orc" || provider1.startsWith("org.apache.spark.sql.hive.orc")) { throw new AnalysisException( "The ORC data source must be used with Hive support enabled") - } else if (provider1.toLowerCase == "avro" || + } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || provider1 == "com.databricks.spark.avro") { throw new AnalysisException( - s"Failed to find data source: ${provider1.toLowerCase}. Please find an Avro " + - "package at http://spark.apache.org/third-party-projects.html") + s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " + + "Please find an Avro package at " + + "http://spark.apache.org/third-party-projects.html") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. 
Please find packages at " + @@ -596,8 +596,8 @@ object DataSource { */ def buildStorageFormatFromOptions(options: Map[String, String]): CatalogStorageFormat = { val path = CaseInsensitiveMap(options).get("path") - val optionsWithoutPath = options.filterKeys(_.toLowerCase != "path") + val optionsWithoutPath = options.filterKeys(_.toLowerCase(Locale.ROOT) != "path") CatalogStorageFormat.empty.copy( - locationUri = path.map(CatalogUtils.stringToURI(_)), properties = optionsWithoutPath) + locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 11605dd28056..9897ab73b0da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -245,7 +245,6 @@ object InMemoryFileIndex extends Logging { sessionOpt: Option[SparkSession]): Seq[FileStatus] = { logTrace(s"Listing $path") val fs = path.getFileSystem(hadoopConf) - val name = path.getName.toLowerCase // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist // Note that statuses only include FileStatus for the files and dirs directly under path, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 03980922ab38..c3583209efc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -194,7 +194,7 @@ object PartitioningUtils { while (!finished) { // Sometimes (e.g., when speculative task is enabled), temporary directories may be left // uncleaned. Here we simply ignore them. 
- if (currentPath.getName.toLowerCase == "_temporary") { + if (currentPath.getName.toLowerCase(Locale.ROOT) == "_temporary") { return (None, None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 4994b8dc8052..62e4c6e4b4ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -71,9 +71,9 @@ class CSVOptions( val param = parameters.getOrElse(paramName, default.toString) if (param == null) { default - } else if (param.toLowerCase == "true") { + } else if (param.toLowerCase(Locale.ROOT) == "true") { true - } else if (param.toLowerCase == "false") { + } else if (param.toLowerCase(Locale.ROOT) == "false") { false } else { throw new Exception(s"$paramName flag can be true or false") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 110d503f91cf..f8d4a9bb5b81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} @@ -75,7 +77,7 @@ case class CreateTempViewUsing( } def run(sparkSession: SparkSession): Seq[Row] = { - if (provider.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (provider.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, " + "you can't use it with CREATE TEMP VIEW USING") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 89fe86c038b1..591096d5efd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager} -import java.util.Properties +import java.util.{Locale, Properties} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -55,7 +55,7 @@ class JDBCOptions( */ val asConnectionProperties: Properties = { val properties = new Properties() - parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase)) + parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase(Locale.ROOT))) .foreach { case (k, v) => properties.setProperty(k, v) } properties } @@ -141,7 +141,7 @@ object JDBCOptions { private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { - jdbcOptionNames += name.toLowerCase + jdbcOptionNames += name.toLowerCase(Locale.ROOT) name } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 774d1ba19432..5fc3c2753b6c 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.util.Locale import scala.collection.JavaConverters._ import scala.util.Try @@ -542,7 +543,7 @@ object JdbcUtils extends Logging { case ArrayType(et, _) => // remove type length parameters from end of type name val typeName = getJdbcType(et, dialect).databaseTypeDefinition - .toLowerCase.split("\\(")(0) + .toLowerCase(Locale.ROOT).split("\\(")(0) (stmt: PreparedStatement, row: Row, pos: Int) => val array = conn.createArrayOf( typeName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index bdda299a621a..772d4565de54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.util.Locale + import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -40,9 +42,11 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase + val codecName = parameters.getOrElse("compression", + sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { - val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) + val availableCodecs = + shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 8b598cc60e77..7abf2ae5166b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ @@ -48,7 +50,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { // will catch it and return the original plan, so that the analyzer can report table not // found later. 
val isFileFormat = classOf[FileFormat].isAssignableFrom(dataSource.providingClass) - if (!isFileFormat || dataSource.className.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (!isFileFormat || + dataSource.className.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Unsupported data source type for direct query on files: " + s"${u.tableIdentifier.database.get}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index f9dd80230e48..1426728f9b55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable @@ -599,7 +600,7 @@ private[state] class HDFSBackedStateStoreProvider( val nameParts = path.getName.split("\\.") if (nameParts.size == 2) { val version = nameParts(0).toLong - nameParts(1).toLowerCase match { + nameParts(1).toLowerCase(Locale.ROOT) match { case "delta" => // ignore the file otherwise, snapshot file already exists for that batch id if (!versionToFiles.contains(version)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index ca46a1151e3e..b9515ec7bca2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.Locale + import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat case class HiveSerDe( @@ -68,7 +70,7 @@ object HiveSerDe { * @return HiveSerDe associated with the specified source */ def sourceToSerDe(source: String): Option[HiveSerDe] = { - val key = source.toLowerCase match { + val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" case s if s.equals("orcfile") => "orc" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 1ef9d52713d9..0289471bf841 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c3a9cfc08517..746b2a94f102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + 
import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} @@ -135,7 +137,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * @since 2.0.0 */ def load(): DataFrame = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index f2f700590ca8..0d2611f9bbcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} @@ -230,7 +232,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def start(): StreamingQuery = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "write files of Hive data source directly.") } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 78cf033dd81d..3ba37addfc8b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -119,7 +119,7 @@ public void testCommonOperation() { Dataset parMapped = ds.mapPartitions((MapPartitionsFunction) it -> { List ls = new LinkedList<>(); while (it.hasNext()) { - ls.add(it.next().toUpperCase(Locale.ENGLISH)); + ls.add(it.next().toUpperCase(Locale.ROOT)); } return ls.iterator(); }, Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 4b69baffab62..d9130fdcfaea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -124,7 +124,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } private def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.exists(t => testCase.name.toLowerCase.contains(t.toLowerCase))) { + if (blackList.exists(t => + testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) { // Create a test case to ignore this case. 
ignore(testCase.name) { /* Do nothing */ } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 8bceab39f71d..1c1931b6a6da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext @@ -24,11 +26,12 @@ class QueryExecutionSuite extends SharedSQLContext { test("toString() exception/error handling") { val badRule = new SparkStrategy { var mode: String = "" - override def apply(plan: LogicalPlan): Seq[SparkPlan] = mode.toLowerCase match { - case "exception" => throw new AnalysisException(mode) - case "error" => throw new Error(mode) - case _ => Nil - } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + mode.toLowerCase(Locale.ROOT) match { + case "exception" => throw new AnalysisException(mode) + case "error" => throw new Error(mode) + case _ => Nil + } } spark.experimental.extraStrategies = badRule :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 13202a57851e..97c61dc8694b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import java.net.URI +import java.util.Locale import scala.reflect.{classTag, ClassTag} @@ -40,8 +41,10 @@ class DDLCommandSuite extends PlanTest { val e = intercept[ParseException] { parser.parsePlan(sql) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) - containsThesePhrases.foreach { p => assert(e.getMessage.toLowerCase.contains(p.toLowerCase)) } + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) + containsThesePhrases.foreach { p => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) + } } private def parseAs[T: ClassTag](query: String): T = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 9ebf2dd839a7..fe74ab49f91b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import java.io.File import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -190,7 +191,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { sql(query) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { @@ -1813,7 +1814,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withTable(tabName) { sql(s"CREATE TABLE $tabName(col1 int, 
col2 string) USING parquet ") val message = intercept[AnalysisException] { - sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase}") + sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase(Locale.ROOT)}") }.getMessage assert(message.contains("SHOW COLUMNS with conflicting databases")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 57a0af1dda97..94a2f9a00b3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.util.Locale + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -300,7 +302,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase) { + assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase(Locale.ROOT)) { compressionCodecFor(path, codec.name()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2b20b9716bf8..b4f3de996120 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger import java.sql.{Date, Timestamp} -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.collection.mutable.ArrayBuffer @@ -476,7 +476,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha assert(partDf.schema.map(_.name) === Seq("intField", "stringField")) path.listFiles().foreach { f => - if (!f.getName.startsWith("_") && f.getName.toLowerCase().endsWith(".parquet")) { + if (!f.getName.startsWith("_") && + f.getName.toLowerCase(Locale.ROOT).endsWith(".parquet")) { // when the input is a path to a parquet file val df = spark.read.parquet(f.getCanonicalPath) assert(df.schema.map(_.name) === Seq("intField", "stringField")) @@ -484,7 +485,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } path.listFiles().foreach { f => - if (!f.getName.startsWith("_") && f.getName.toLowerCase().endsWith(".parquet")) { + if (!f.getName.startsWith("_") && + f.getName.toLowerCase(Locale.ROOT).endsWith(".parquet")) { // when the input is a path to a parquet file but `basePath` is overridden to // the base path containing partitioning directories val df = spark diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index be56c964a18f..5a0388ec1d1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.util.Locale + import scala.language.existentials import org.apache.spark.rdd.RDD @@ -76,7 +78,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S case "b" => (i: Int) => Seq(i * 2) case "c" => (i: Int) => val c = (i - 1 + 'a').toChar.toString - Seq(c * 5 + c.toUpperCase * 5) + Seq(c * 5 + c.toUpperCase(Locale.ROOT) * 5) } FiltersPushed.list = filters @@ -113,7 +115,8 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S } def eval(a: Int) = { - val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase * 5 + val c = (a - 1 + 'a').toChar.toString * 5 + + (a - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5 filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c)) } @@ -151,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic sqlTest( "SELECT * FROM oneToTenFiltered", (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 - + (i - 1 + 'a').toChar.toString.toUpperCase * 5)).toSeq) + + (i - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index f67444fbc49d..1211242b9fbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -221,7 +223,7 @@ class FileStreamSinkSuite extends StreamTest { df.writeStream.format("parquet").outputMode(mode).start(dir.getCanonicalPath) } Seq(mode, "not support").foreach { w => - assert(e.getMessage.toLowerCase.contains(w)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(w)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index e5d5b4f32882..f796a4cb4a39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.TimeZone +import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfterAll @@ -105,7 +105,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte testStream(aggregated, Append)() } Seq("append", "not supported").foreach { m => - assert(e.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 05cd3d9f7c2f..dc2506a48ad0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming.test import java.io.File +import java.util.Locale import java.util.concurrent.TimeUnit import scala.concurrent.duration._ @@ -126,7 +127,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .save() } Seq("'write'", "not", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -400,7 +401,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { var w = df.writeStream var e = intercept[IllegalArgumentException](w.foreach(null)) Seq("foreach", "null").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -417,7 +418,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { var w = df.writeStream.partitionBy("value") var e = intercept[AnalysisException](w.foreach(foreachWriter).start()) Seq("foreach", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 7c71e7280c6d..fb15e7def6db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.test import java.io.File +import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue import org.scalatest.BeforeAndAfter @@ -144,7 +145,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be .start() } Seq("'writeStream'", "only", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -276,13 +277,13 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be var w = df.write.partitionBy("value") var e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("jdbc", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("jdbc", "bucketing").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -385,7 +386,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be // Reader, with user specified schema, should just apply user schema on the file data val e = intercept[AnalysisException] { spark.read.schema(userSchema).textFile() } - assert(e.getMessage.toLowerCase.contains("user specified schema not supported")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains( + "user specified schema not supported")) intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) } intercept[AnalysisException] { 
spark.read.schema(userSchema).textFile(dir, dir) } intercept[AnalysisException] { spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index 1e6ac4f3df47..c5ade6528304 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import javax.net.ssl.SSLServerSocket; @@ -259,12 +260,12 @@ public static TServerSocket getServerSSLSocket(String hiveHost, int portNum, Str if (thriftServerSocket.getServerSocket() instanceof SSLServerSocket) { List sslVersionBlacklistLocal = new ArrayList(); for (String sslVersion : sslVersionBlacklist) { - sslVersionBlacklistLocal.add(sslVersion.trim().toLowerCase()); + sslVersionBlacklistLocal.add(sslVersion.trim().toLowerCase(Locale.ROOT)); } SSLServerSocket sslServerSocket = (SSLServerSocket) thriftServerSocket.getServerSocket(); List enabledProtocols = new ArrayList(); for (String protocol : sslServerSocket.getEnabledProtocols()) { - if (sslVersionBlacklistLocal.contains(protocol.toLowerCase())) { + if (sslVersionBlacklistLocal.contains(protocol.toLowerCase(Locale.ROOT))) { LOG.debug("Disabling SSL Protocol: " + protocol); } else { enabledProtocols.add(protocol); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java index ab3ac6285aa0..ad4dfd75f470 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java @@ -19,6 +19,7 @@ package org.apache.hive.service.auth; import java.util.HashMap; +import java.util.Locale; import java.util.Map; /** @@ -52,7 +53,7 @@ public String toString() { public static SaslQOP fromString(String str) { if (str != null) { - str = str.toLowerCase(); + str = str.toLowerCase(Locale.ROOT); } SaslQOP saslQOP = STR_TO_ENUM.get(str); if (saslQOP == null) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java index a96d2ac371cd..7752ec03a29b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java @@ -19,6 +19,7 @@ package org.apache.hive.service.cli; import java.sql.DatabaseMetaData; +import java.util.Locale; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hive.service.cli.thrift.TTypeId; @@ -160,7 +161,7 @@ public static Type getType(String name) { if (name.equalsIgnoreCase(type.name)) { return type; } else if (type.isQualifiedType() || type.isComplexType()) { - if (name.toUpperCase().startsWith(type.name)) { + if (name.toUpperCase(Locale.ROOT).startsWith(type.name)) { return type; } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 14553601b1d5..5e4734ad3ad2 100644 --- 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -294,7 +294,7 @@ private[hive] class HiveThriftServer2(sqlContext: SQLContext) private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.toLowerCase(Locale.ENGLISH).equals("http") + transportMode.toLowerCase(Locale.ROOT).equals("http") } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 1bc5c3c62f04..d5cc3b385504 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -302,7 +302,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() - val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() if (cmd_lower.equals("quit") || @@ -310,7 +310,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { sessionState.close() System.exit(0) } - if (tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || + if (tokens(0).toLowerCase(Locale.ROOT).equals("source") || cmd_trimmed.startsWith("!") || isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index f0e35dff57f7..806f2be5faeb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.IOException import java.lang.reflect.InvocationTargetException import java.util +import java.util.Locale import scala.collection.mutable import scala.util.control.NonFatal @@ -499,7 +500,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // We can't use `filterKeys` here, as the map returned by `filterKeys` is not serializable, // while `CatalogTable` should be serializable. val propsWithoutPath = table.storage.properties.filter { - case (k, v) => k.toLowerCase != "path" + case (k, v) => k.toLowerCase(Locale.ROOT) != "path" } table.storage.copy(properties = propsWithoutPath ++ newPath.map("path" -> _)) } @@ -1060,7 +1061,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Hive's metastore is case insensitive. However, Hive's createFunction does // not normalize the function name (unlike the getFunction part). So, // we are normalizing the function name. 
- val functionName = funcDefinition.identifier.funcName.toLowerCase + val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) requireFunctionNotExists(db, functionName) val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 9e3eb2dd8234..c917f110b90f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.util.Locale + import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal @@ -143,7 +145,7 @@ private[sql] class HiveSessionCatalog( // This function is not in functionRegistry, let's try to load it as a Hive's // built-in function. // Hive is case insensitive. - val functionName = funcName.unquotedString.toLowerCase + val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) if (!hiveFunctions.contains(functionName)) { failFunctionLookup(funcName.unquotedString) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0465e9c031e2..09a5eda6e543 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.IOException +import java.util.Locale import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst @@ -184,14 +185,14 @@ case class RelationConversions( conf: SQLConf, sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { private def isConvertible(relation: CatalogRelation): Boolean = { - (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET)) || - (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && - conf.getConf(HiveUtils.CONVERT_METASTORE_ORC)) + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || + serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } private def convert(relation: CatalogRelation): LogicalRelation = { - if (relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet")) { + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + if (serde.contains("parquet")) { val options = Map(ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index afc2bf85334d..3de60c7fc131 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -21,6 +21,7 @@ import java.io.File import java.net.{URL, URLClassLoader} import java.nio.charset.StandardCharsets import java.sql.Timestamp +import java.util.Locale import java.util.concurrent.TimeUnit import 
scala.collection.mutable.HashMap @@ -338,7 +339,7 @@ private[spark] object HiveUtils extends Logging { logWarning(s"Hive jar path '$path' does not exist.") Nil } else { - files.filter(_.getName.toLowerCase.endsWith(".jar")) + files.filter(_.getName.toLowerCase(Locale.ROOT).endsWith(".jar")) } case path => new File(path) :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 56ccac32a8d8..387ec4f96723 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} +import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -153,7 +154,7 @@ private[hive] class HiveClientImpl( hadoopConf.iterator().asScala.foreach { entry => val key = entry.getKey val value = entry.getValue - if (key.toLowerCase.contains("password")) { + if (key.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=xxx") } else { logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=$value") @@ -168,7 +169,7 @@ private[hive] class HiveClientImpl( hiveConf.setClassLoader(initClassLoader) // 2: we set all spark confs to this hiveConf. sparkConf.getAll.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { + if (k.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying Spark config to Hive Conf: $k=xxx") } else { logDebug(s"Applying Spark config to Hive Conf: $k=$v") @@ -177,7 +178,7 @@ private[hive] class HiveClientImpl( } // 3: we set all entries in config to this hiveConf. 
extraConfig.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { + if (k.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying extra config to HiveConf: $k=xxx") } else { logDebug(s"Applying extra config to HiveConf: $k=$v") @@ -622,7 +623,7 @@ private[hive] class HiveClientImpl( */ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState { logDebug(s"Running hiveql '$cmd'") - if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") } + if (cmd.toLowerCase(Locale.ROOT).startsWith("set")) { logDebug(s"Changing config: $cmd") } try { val cmd_trimmed: String = cmd.trim() val tokens: Array[String] = cmd_trimmed.split("\\s+") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 2e35f3983948..7abb9f06b131 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.client import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{InvocationTargetException, Method, Modifier} import java.net.URI -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, Set => JSet} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -505,8 +505,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { val resourceUris = f.resources.map { resource => - new ResourceUri( - ResourceType.valueOf(resource.resourceType.resourceType.toUpperCase()), resource.uri) + new ResourceUri(ResourceType.valueOf( + resource.resourceType.resourceType.toUpperCase(Locale.ROOT)), resource.uri) } new HiveFunction( f.identifier.funcName, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 192851028031..5c515515b9b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.util.Locale + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** @@ -29,7 +31,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) - val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase) + val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase(Locale.ROOT)) val inputFormat = parameters.get(INPUT_FORMAT) val outputFormat = parameters.get(OUTPUT_FORMAT) @@ -75,7 +77,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) } def serdeProperties: Map[String, String] = parameters.filterKeys { - k => !lowerCasedOptionNames.contains(k.toLowerCase) + k => !lowerCasedOptionNames.contains(k.toLowerCase(Locale.ROOT)) }.map { case (k, v) => delimiterOptions.getOrElse(k, k) -> v } } @@ -83,7 +85,7 @@ object HiveOptions { private val lowerCasedOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { - lowerCasedOptionNames += name.toLowerCase + lowerCasedOptionNames += 
name.toLowerCase(Locale.ROOT) name } @@ -99,5 +101,5 @@ object HiveOptions { // The following typo is inherited from Hive... "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", - "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase -> v } + "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index ccaa568dcce2..043eb69818ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.orc +import java.util.Locale + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** @@ -41,9 +43,9 @@ private[orc] class OrcOptions(@transient private val parameters: CaseInsensitive val codecName = parameters .get("compression") .orElse(orcCompressionConf) - .getOrElse("snappy").toLowerCase + .getOrElse("snappy").toLowerCase(Locale.ROOT) if (!shortOrcCompressionCodecNames.contains(codecName)) { - val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase) + val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 490e02d0bd54..59cc6605a124 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.net.URI +import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier @@ -49,7 +50,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val e = intercept[ParseException] { parser.parsePlan(sql) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def analyzeCreateTable(sql: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index e48ce2304d08..319d02613f00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -18,18 +18,15 @@ package org.apache.spark.sql.hive import java.io.File -import java.util.concurrent.{Executors, TimeUnit} import scala.util.Random import org.scalatest.BeforeAndAfterEach -import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.datasources.FileStatusCache import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode.{Value => InferenceMode, _} diff 
--git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index e45cf977bfaa..abe5d835719b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io._ import java.nio.charset.StandardCharsets import java.util +import java.util.Locale import scala.util.control.NonFatal @@ -299,10 +300,11 @@ abstract class HiveComparisonTest // thus the tables referenced in those DDL commands cannot be extracted for use by our // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES // command expect these tables to exist. - val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + val hasShowTableCommand = + queryList.exists(_.toLowerCase(Locale.ROOT).contains("show tables")) for (table <- Seq("src", "srcpart")) { val hasMatchingQuery = queryList.exists { query => - val normalizedQuery = query.toLowerCase.stripSuffix(";") + val normalizedQuery = query.toLowerCase(Locale.ROOT).stripSuffix(";") normalizedQuery.endsWith(table) || normalizedQuery.contains(s"from $table") || normalizedQuery.contains(s"from default.$table") @@ -444,7 +446,7 @@ abstract class HiveComparisonTest "create table", "drop index" ) - !queryList.map(_.toLowerCase).exists { query => + !queryList.map(_.toLowerCase(Locale.ROOT)).exists { query => excludedSubstrings.exists(s => query.contains(s)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 65a902fc5438..cf3376036072 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -80,7 +80,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd private def assertUnsupportedFeature(body: => Unit): Unit = { val e = intercept[ParseException] { body } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } // Testing the Broadcast based join for cartesian join (cross join) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index d012797e1992..75f3744ff35b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.util.Locale import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -475,13 +476,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { case None => // OK. } // Also make sure that the format and serde are as desired. 
- assert(catalogTable.storage.inputFormat.get.toLowerCase.contains(format)) - assert(catalogTable.storage.outputFormat.get.toLowerCase.contains(format)) + assert(catalogTable.storage.inputFormat.get.toLowerCase(Locale.ROOT).contains(format)) + assert(catalogTable.storage.outputFormat.get.toLowerCase(Locale.ROOT).contains(format)) val serde = catalogTable.storage.serde.get format match { case "sequence" | "text" => assert(serde.contains("LazySimpleSerDe")) case "rcfile" => assert(serde.contains("LazyBinaryColumnarSerDe")) - case _ => assert(serde.toLowerCase.contains(format)) + case _ => assert(serde.toLowerCase(Locale.ROOT).contains(format)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 9a760e2947d0..931f015f03b6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import java.util.Locale + import scala.reflect.ClassTag import org.apache.spark.SparkContext @@ -60,7 +62,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) .split("(?=[A-Z])") .filter(_.nonEmpty) .mkString(" ") - .toLowerCase + .toLowerCase(Locale.ROOT) .capitalize s"$newName [$id]" } @@ -74,7 +76,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) protected[streaming] override val baseScope: Option[String] = { val scopeName = Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) .map { json => RDDOperationScope.fromJson(json).name + s" [$id]" } - .getOrElse(name.toLowerCase) + .getOrElse(name.toLowerCase(Locale.ROOT)) Some(new RDDOperationScope(scopeName).toJson) } diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java index 80513de4ee11..90d1f8c5035b 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java @@ -101,7 +101,7 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions(in -> { String out = ""; while (in.hasNext()) { - out = out + in.next().toUpperCase(); + out = out + in.next().toUpperCase(Locale.ROOT); } return Arrays.asList(out).iterator(); }); @@ -806,7 +806,8 @@ public void testMapValues() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream mapped = pairStream.mapValues(String::toUpperCase); + JavaPairDStream mapped = + pairStream.mapValues(s -> s.toUpperCase(Locale.ROOT)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index 96f8d9593d63..6c86cacec827 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -267,7 +267,7 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions(in -> { StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out.append(in.next().toUpperCase(Locale.ENGLISH)); + out.append(in.next().toUpperCase(Locale.ROOT)); } return Arrays.asList(out.toString()).iterator(); }); @@ -1315,7 
+1315,7 @@ public void testMapValues() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream mapped = - pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); + pairStream.mapValues(s -> s.toUpperCase(Locale.ROOT)); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5645996de5a6..eb996c93ff38 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.{File, NotSerializableException} +import java.util.Locale import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger @@ -745,7 +746,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val ex = intercept[IllegalStateException] { body } - assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(expectedErrorMsg)) } } From f6dd8e0e1673aa491b895c1f0467655fa4e9d52f Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Mon, 10 Apr 2017 21:56:21 +0200 Subject: [PATCH 251/512] [SPARK-20280][CORE] FileStatusCache Weigher integer overflow ## What changes were proposed in this pull request? Weigher.weigh needs to return Int but it is possible for an Array[FileStatus] to have size > Int.maxValue. To avoid this, the size is scaled down by a factor of 32. The maximumWeight of the cache is also scaled down by the same factor. ## How was this patch tested? New test in FileIndexSuite Author: Bogdan Raducanu Closes #17591 from bogdanrdc/SPARK-20280. --- .../datasources/FileStatusCache.scala | 47 ++++++++++++++----- .../datasources/FileIndexSuite.scala | 16 +++++++ 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala index 5d9755863314..aea27bd4c4d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala @@ -94,27 +94,48 @@ private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { // Opaque object that uniquely identifies a shared cache user private type ClientId = Object + private val warnedAboutEviction = new AtomicBoolean(false) // we use a composite cache key in order to distinguish entries inserted by different clients - private val cache: Cache[(ClientId, Path), Array[FileStatus]] = CacheBuilder.newBuilder() - .weigher(new Weigher[(ClientId, Path), Array[FileStatus]] { + private val cache: Cache[(ClientId, Path), Array[FileStatus]] = { + // [[Weigher]].weigh returns Int so we could only cache objects < 2GB + // instead, the weight is divided by this factor (which is smaller + // than the size of one [[FileStatus]]). + // so it will support objects up to 64GB in size. 
+ val weightScale = 32 + val weigher = new Weigher[(ClientId, Path), Array[FileStatus]] { override def weigh(key: (ClientId, Path), value: Array[FileStatus]): Int = { - (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)).toInt - }}) - .removalListener(new RemovalListener[(ClientId, Path), Array[FileStatus]]() { - override def onRemoval(removed: RemovalNotification[(ClientId, Path), Array[FileStatus]]) - : Unit = { + val estimate = (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)) / weightScale + if (estimate > Int.MaxValue) { + logWarning(s"Cached table partition metadata size is too big. Approximating to " + + s"${Int.MaxValue.toLong * weightScale}.") + Int.MaxValue + } else { + estimate.toInt + } + } + } + val removalListener = new RemovalListener[(ClientId, Path), Array[FileStatus]]() { + override def onRemoval( + removed: RemovalNotification[(ClientId, Path), + Array[FileStatus]]): Unit = { if (removed.getCause == RemovalCause.SIZE && - warnedAboutEviction.compareAndSet(false, true)) { + warnedAboutEviction.compareAndSet(false, true)) { logWarning( "Evicting cached table partition metadata from memory due to size constraints " + - "(spark.sql.hive.filesourcePartitionFileCacheSize = " + maxSizeInBytes + " bytes). " + - "This may impact query planning performance.") + "(spark.sql.hive.filesourcePartitionFileCacheSize = " + + maxSizeInBytes + " bytes). This may impact query planning performance.") } - }}) - .maximumWeight(maxSizeInBytes) - .build[(ClientId, Path), Array[FileStatus]]() + } + } + CacheBuilder.newBuilder() + .weigher(weigher) + .removalListener(removalListener) + .maximumWeight(maxSizeInBytes / weightScale) + .build[(ClientId, Path), Array[FileStatus]]() + } + /** * @return a FileStatusCache that does not share any entries with any other client, but does diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 00f5d5db8f5f..a9511cbd9e4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} class FileIndexSuite extends SharedSQLContext { @@ -220,6 +221,21 @@ class FileIndexSuite extends SharedSQLContext { assert(catalog.leafDirPaths.head == fs.makeQualified(dirPath)) } } + + test("SPARK-20280 - FileStatusCache with a partition with very many files") { + /* fake the size, otherwise we need to allocate 2GB of data to trigger this bug */ + class MyFileStatus extends FileStatus with KnownSizeEstimation { + override def estimatedSize: Long = 1000 * 1000 * 1000 + } + /* files * MyFileStatus.estimatedSize should overflow to negative integer + * so, make it between 2bn and 4bn + */ + val files = (1 to 3).map { i => + new MyFileStatus() + } + val fileStatusCache = FileStatusCache.getOrCreate(spark) + fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray) + } } class FakeParentPathFileSystem extends RawLocalFileSystem { From f9a50ba2d1bfa3f55199df031e71154611ba51f6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 10 Apr 2017 14:06:49 -0700 Subject: [PATCH 252/512] [SPARK-20285][TESTS] 
Increase the pyspark streaming test timeout to 30 seconds ## What changes were proposed in this pull request? Saw the following failure locally: ``` Traceback (most recent call last): File "/home/jenkins/workspace/python/pyspark/streaming/tests.py", line 351, in test_cogroup self._test_func(input, func, expected, sort=True, input2=input2) File "/home/jenkins/workspace/python/pyspark/streaming/tests.py", line 162, in _test_func self.assertEqual(expected, result) AssertionError: Lists differ: [[(1, ([1], [2])), (2, ([1], [... != [] First list contains 3 additional elements. First extra element 0: [(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))] + [] - [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], - [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], - [('', ([1, 1], [1, 2])), ('a', ([1, 1], [1, 1])), ('b', ([1], [1]))]] ``` It also happened on Jenkins: http://spark-tests.appspot.com/builds/spark-branch-2.1-test-sbt-hadoop-2.7/120 It's because when the machine is overloaded, the timeout is not long enough. This PR just increases the timeout to 30 seconds. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17597 from zsxwing/SPARK-20285. --- python/pyspark/streaming/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 1bec33509580..ffba99502b14 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -55,7 +55,7 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 10 # seconds + timeout = 30 # seconds duration = .5 @classmethod From a35b9d97123697d23fa0f691c1054f9adab5956c Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 10 Apr 2017 14:09:32 -0700 Subject: [PATCH 253/512] [SPARK-20282][SS][TESTS] Write the commit log first to fix a race condition in tests ## What changes were proposed in this pull request?
This PR fixes the following failure: ``` sbt.ForkMain$ForkError: org.scalatest.exceptions.TestFailedException: Assert on query failed: == Progress == AssertOnQuery(, ) StopStream AddData to MemoryStream[value#30891]: 1,2 StartStream(OneTimeTrigger,org.apache.spark.util.SystemClock35cdc93a,Map()) CheckAnswer: [6],[3] StopStream => AssertOnQuery(, ) AssertOnQuery(, ) StartStream(OneTimeTrigger,org.apache.spark.util.SystemClockcdb247d,Map()) CheckAnswer: [6],[3] StopStream AddData to MemoryStream[value#30891]: 3 StartStream(OneTimeTrigger,org.apache.spark.util.SystemClock55394e4d,Map()) CheckLastBatch: [2] StopStream AddData to MemoryStream[value#30891]: 0 StartStream(OneTimeTrigger,org.apache.spark.util.SystemClock749aa997,Map()) ExpectFailure[org.apache.spark.SparkException, isFatalError: false] AssertOnQuery(, ) AssertOnQuery(, incorrect start offset or end offset on exception) == Stream == Output Mode: Append Stream state: not started Thread state: dead == Sink == 0: [6] [3] == Plan == at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:495) at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555) at org.scalatest.Assertions$class.fail(Assertions.scala:1328) at org.scalatest.FunSuite.fail(FunSuite.scala:1555) at org.apache.spark.sql.streaming.StreamTest$class.failTest$1(StreamTest.scala:347) at org.apache.spark.sql.streaming.StreamTest$class.verify$1(StreamTest.scala:318) at org.apache.spark.sql.streaming.StreamTest$$anonfun$liftedTree1$1$1.apply(StreamTest.scala:483) at org.apache.spark.sql.streaming.StreamTest$$anonfun$liftedTree1$1$1.apply(StreamTest.scala:357) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at org.apache.spark.sql.streaming.StreamTest$class.liftedTree1$1(StreamTest.scala:357) at org.apache.spark.sql.streaming.StreamTest$class.testStream(StreamTest.scala:356) at org.apache.spark.sql.streaming.StreamingQuerySuite.testStream(StreamingQuerySuite.scala:41) at org.apache.spark.sql.streaming.StreamingQuerySuite$$anonfun$6.apply$mcV$sp(StreamingQuerySuite.scala:166) at org.apache.spark.sql.streaming.StreamingQuerySuite$$anonfun$6.apply(StreamingQuerySuite.scala:161) at org.apache.spark.sql.streaming.StreamingQuerySuite$$anonfun$6.apply(StreamingQuerySuite.scala:161) at org.apache.spark.sql.catalyst.util.package$.quietly(package.scala:42) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$testQuietly$1.apply$mcV$sp(SQLTestUtils.scala:268) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$testQuietly$1.apply(SQLTestUtils.scala:268) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$testQuietly$1.apply(SQLTestUtils.scala:268) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) 
at org.apache.spark.sql.streaming.StreamingQuerySuite.org$scalatest$BeforeAndAfterEach$$super$runTest(StreamingQuerySuite.scala:41) at org.scalatest.BeforeAndAfterEach$class.runTest(BeforeAndAfterEach.scala:255) at org.apache.spark.sql.streaming.StreamingQuerySuite.org$scalatest$BeforeAndAfter$$super$runTest(StreamingQuerySuite.scala:41) at org.scalatest.BeforeAndAfter$class.runTest(BeforeAndAfter.scala:200) at org.apache.spark.sql.streaming.StreamingQuerySuite.runTest(StreamingQuerySuite.scala:41) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.sql.streaming.StreamingQuerySuite.org$scalatest$BeforeAndAfter$$super$run(StreamingQuerySuite.scala:41) at org.scalatest.BeforeAndAfter$class.run(BeforeAndAfter.scala:241) at org.apache.spark.sql.streaming.StreamingQuerySuite.run(StreamingQuerySuite.scala:41) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:357) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:502) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` The failure is because `CheckAnswer` will run once `committedOffsets` is updated. Then writing the commit log may be interrupted by the following `StopStream`. This PR just changes the order to write the commit log first. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17594 from zsxwing/SPARK-20282.
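To make the intended ordering concrete, the following is a minimal, self-contained sketch (not the actual `StreamExecution` code; the object, method, and field names are illustrative stand-ins) of recording the batch in the commit log before advancing the in-memory committed offsets:

```scala
import scala.collection.mutable

// Illustrative stand-in for the batch-completion step described above: the commit
// log entry is recorded first, so an observer that reacts as soon as the committed
// offsets advance can no longer interrupt a still-pending commit-log write.
object CommitOrderingSketch {
  private val batchCommitLog = mutable.ArrayBuffer.empty[Long]    // stand-in for the real commit log
  private var committedOffsets = Map.empty[String, Long]          // stand-in for committed source offsets

  def finishBatch(batchId: Long, availableOffsets: Map[String, Long]): Unit = {
    batchCommitLog += batchId               // 1. make the batch commit durable first
    committedOffsets ++= availableOffsets   // 2. only then let observers see the batch as committed
  }

  def main(args: Array[String]): Unit = {
    finishBatch(0L, Map("source-0" -> 2L))
    println(s"commit log: $batchCommitLog, committed offsets: $committedOffsets")
  }
}
```

The actual fix is just the two-line reordering inside `StreamExecution` shown in the diff that follows.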
--- .../apache/spark/sql/execution/streaming/StreamExecution.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 5f548172f5ce..8857966676ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -304,8 +304,8 @@ class StreamExecution( finishTrigger(dataAvailable) if (dataAvailable) { // Update committed offsets. - committedOffsets ++= availableOffsets batchCommitLog.add(currentBatchId) + committedOffsets ++= availableOffsets logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 From 379b0b0bbdbba2278ce3bcf471bd75f6ffd9cf0d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 10 Apr 2017 14:14:09 -0700 Subject: [PATCH 254/512] [SPARK-20283][SQL] Add preOptimizationBatches ## What changes were proposed in this pull request? We currently have postHocOptimizationBatches, but not preOptimizationBatches. This patch adds preOptimizationBatches so the optimizer debugging extensions are symmetric. ## How was this patch tested? N/A Author: Reynold Xin Closes #17595 from rxin/SPARK-20283. --- .../org/apache/spark/sql/execution/SparkOptimizer.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 2cdfb7a7828c..1de4f508b89a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -30,13 +30,19 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog, conf) { - override def batches: Seq[Batch] = (super.batches :+ + override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + /** + * Optimization batches that are executed before the regular optimization batches (also before + * the finish analysis batch). + */ + def preOptimizationBatches: Seq[Batch] = Nil + /** * Optimization batches that are executed after the regular optimization batches, but before the * batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add From 734dfbfcfea1ed1ab3a5f18f84c412a569dd87e7 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 10 Apr 2017 20:41:08 -0700 Subject: [PATCH 255/512] [SPARK-17564][TESTS] Fix flaky RequestTimeoutIntegrationSuite.furtherRequestsDelay ## What changes were proposed in this pull request? 
This PR fixes the following failure: ``` sbt.ForkMain$ForkError: java.lang.AssertionError: null at org.junit.Assert.fail(Assert.java:86) at org.junit.Assert.assertTrue(Assert.java:41) at org.junit.Assert.assertTrue(Assert.java:52) at org.apache.spark.network.RequestTimeoutIntegrationSuite.furtherRequestsDelay(RequestTimeoutIntegrationSuite.java:230) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:497) at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50) at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12) at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47) at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17) at org.junit.internal.runners.statements.RunBefores.evaluate(RunBefores.java:26) at org.junit.internal.runners.statements.RunAfters.evaluate(RunAfters.java:27) at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325) at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78) at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57) at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) at org.junit.runners.ParentRunner.run(ParentRunner.java:363) at org.junit.runners.Suite.runChild(Suite.java:128) at org.junit.runners.Suite.runChild(Suite.java:27) at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290) at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71) at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288) at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58) at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268) at org.junit.runners.ParentRunner.run(ParentRunner.java:363) at org.junit.runner.JUnitCore.run(JUnitCore.java:137) at org.junit.runner.JUnitCore.run(JUnitCore.java:115) at com.novocode.junit.JUnitRunner$1.execute(JUnitRunner.java:132) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` It happens several times per month on [Jenkins](http://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.network.RequestTimeoutIntegrationSuite&test_name=furtherRequestsDelay). The failure is because `callback1` may not be called before `assertTrue(callback1.failure instanceof IOException);`. It's pretty easy to reproduce this error by adding a sleep before this line: https://github.com/apache/spark/blob/379b0b0bbdbba2278ce3bcf471bd75f6ffd9cf0d/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java#L267 The fix is straightforward: just use the latch to wait until `callback1` is called. ## How was this patch tested?
Jenkins Author: Shixiong Zhu Closes #17599 from zsxwing/SPARK-17564. --- .../apache/spark/network/RequestTimeoutIntegrationSuite.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 9aa17e24b624..c0724e018263 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -225,6 +225,8 @@ public StreamManager getStreamManager() { callback0.latch.await(60, TimeUnit.SECONDS); assertTrue(callback0.failure instanceof IOException); + // make sure callback1 is called. + callback1.latch.await(60, TimeUnit.SECONDS); // failed at same time as previous assertTrue(callback1.failure instanceof IOException); } From 0d2b796427a59d3e9967b62618be301307f29162 Mon Sep 17 00:00:00 2001 From: Benjamin Fradet Date: Tue, 11 Apr 2017 09:12:49 +0200 Subject: [PATCH 256/512] [SPARK-20097][ML] Fix visibility discrepancy with numInstances and degreesOfFreedom in LR and GLR ## What changes were proposed in this pull request? - made `numInstances` public in GLR - made `degreesOfFreedom` public in LR ## How was this patch tested? reran the concerned test suites Author: Benjamin Fradet Closes #17431 from BenFradet/SPARK-20097. --- .../spark/ml/regression/GeneralizedLinearRegression.scala | 3 ++- .../org/apache/spark/ml/regression/LinearRegression.scala | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 33137b0c0fde..d6093a01c671 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1133,7 +1133,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( private[regression] lazy val link: Link = familyLink.link /** Number of instances in DataFrame predictions. */ - private[regression] lazy val numInstances: Long = predictions.count() + @Since("2.2.0") + lazy val numInstances: Long = predictions.count() /** The numeric rank of the fitted linear model. */ @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 45df1d9be647..f7e3c8fa5b6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -696,7 +696,8 @@ class LinearRegressionSummary private[regression] ( lazy val numInstances: Long = predictions.count() /** Degrees of freedom */ - private val degreesOfFreedom: Long = if (privateModel.getFitIntercept) { + @Since("2.2.0") + val degreesOfFreedom: Long = if (privateModel.getFitIntercept) { numInstances - privateModel.coefficients.size - 1 } else { numInstances - privateModel.coefficients.size From d11ef3d77ec2136d6b28bd69f5dd2cc0a22e4717 Mon Sep 17 00:00:00 2001 From: MirrorZ Date: Tue, 11 Apr 2017 10:34:39 +0100 Subject: [PATCH 257/512] Document Master URL format in high availability set up ## What changes were proposed in this pull request? 
Add documentation for specifying the master URL in a multi-host, multi-port format for a standalone cluster with ZooKeeper-based high availability. See the documentation on [Standby Masters with ZooKeeper](http://spark.apache.org/docs/latest/spark-standalone.html#standby-masters-with-zookeeper). ## How was this patch tested? Documenting the functionality already present. Author: MirrorZ Closes #17584 from MirrorZ/master. --- docs/submitting-applications.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index d23dbcf10d95..866d6e527549 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -143,6 +143,9 @@ The master URL passed to Spark can be in one of the following formats: spark://HOST:PORT Connect to the given Spark standalone cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. + spark://HOST1:PORT1,HOST2:PORT2 Connect to the given Spark standalone + cluster with standby masters with Zookeeper. The list must have all the master hosts in the high availability cluster set up with Zookeeper. The port must be whichever each master is configured to use, which is 7077 by default. + mesos://HOST:PORT Connect to the given Mesos cluster. The port must be whichever one your is configured to use, which is 5050 by default. Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... From c8706980ae07362ae5963829e9ada5007eada46b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 11 Apr 2017 20:21:04 +0800 Subject: [PATCH 258/512] [SPARK-20274][SQL] support compatible array element type in encoder ## What changes were proposed in this pull request? This is a regression caused by SPARK-19716. Before SPARK-19716, we would cast an array field to the expected array type. However, after SPARK-19716, the cast is removed, but we forgot to push the cast to the element level. ## How was this patch tested? new regression tests Author: Wenchen Fan Closes #17587 from cloud-fan/array.
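To illustrate the regressed case, here is a minimal, self-contained sketch (the case class and object names are made up for this example and are not part of the patch): the underlying column stores `array<int>`, while the encoder's target field is `Array[Long]`, so the cast has to be applied at the element level.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical target class; its encoder expects the `arr` column to be array<bigint>.
case class PrimitiveArrayHolder(arr: Array[Long])

object ArrayElementUpcastExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("array-element-upcast").getOrCreate()
    import spark.implicits._

    // The DataFrame column `arr` has type array<int>, which is compatible with
    // (safely up-castable to) the Array[Long] field of the target class.
    val df = Seq(Tuple1(Array(1, 2, 3))).toDF("arr")
    val ds = df.as[PrimitiveArrayHolder]   // needs an int -> bigint cast per array element
    ds.collect().foreach(h => println(h.arr.mkString(",")))

    spark.stop()
  }
}
```

The new `EncoderResolutionSuite` tests in the diff that follows cover exactly this compatible-element case, as well as the incompatible `string` -> `bigint` case that must still be rejected.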
--- .../spark/sql/catalyst/ScalaReflection.scala | 18 +++++++++------ .../sql/catalyst/analysis/Analyzer.scala | 8 +++++-- .../encoders/EncoderResolutionSuite.scala | 23 +++++++++++++++++++ 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 198122759e4a..0c5a818f54f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -132,7 +132,7 @@ object ScalaReflection extends ScalaReflection { def deserializerFor[T : TypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + val walkedTypePath = s"""- root class: "$clsName"""" :: Nil deserializerFor(tpe, None, walkedTypePath) } @@ -270,12 +270,14 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(_, elementNullable) = schemaFor(elementType) + val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val mapFunction: Expression => Expression = p => { - val converter = deserializerFor(elementType, Some(p), newTypePath) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. + val casted = upCastToExpectedType(element, dataType, newTypePath) + val converter = deserializerFor(elementType, Some(casted), newTypePath) if (elementNullable) { converter } else { @@ -305,12 +307,14 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(_, elementNullable) = schemaFor(elementType) + val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val mapFunction: Expression => Expression = p => { - val converter = deserializerFor(elementType, Some(p), newTypePath) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. 
+ val casted = upCastToExpectedType(element, dataType, newTypePath) + val converter = deserializerFor(elementType, Some(casted), newTypePath) if (elementNullable) { converter } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b0cdef70297c..9816b33ae8df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects} +import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ @@ -2321,7 +2321,11 @@ class Analyzer( */ object ResolveUpCast extends Rule[LogicalPlan] { private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { - throw new AnalysisException(s"Cannot up cast ${from.sql} from " + + val fromStr = from match { + case l: LambdaVariable => "array element" + case e => e.sql + } + throw new AnalysisException(s"Cannot up cast $fromStr from " + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index e5a3e1fd374d..630e8a7990e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -33,6 +33,8 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) +case class PrimitiveArrayClass(arr: Array[Long]) + case class ArrayClass(arr: Seq[StringIntClass]) case class NestedArrayClass(nestedArr: Array[ArrayClass]) @@ -66,6 +68,27 @@ class EncoderResolutionSuite extends PlanTest { encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } + test("real type doesn't match encoder schema but they are compatible: primitive array") { + val encoder = ExpressionEncoder[PrimitiveArrayClass] + val attrs = Seq('arr.array(IntegerType)) + val array = new GenericArrayData(Array(1, 2, 3)) + encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + } + + test("the real type is not compatible with encoder schema: primitive array") { + val encoder = ExpressionEncoder[PrimitiveArrayClass] + val attrs = Seq('arr.array(StringType)) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + s""" + |Cannot up cast array element from string to bigint as it may truncate + |The type path of the target object is: + |- array element class: "scala.Long" + |- field (class: "scala.Array", name: "arr") + |- root class: 
"org.apache.spark.sql.catalyst.encoders.PrimitiveArrayClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + test("real type doesn't match encoder schema but they are compatible: array") { val encoder = ExpressionEncoder[ArrayClass] val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) From cd91f967145909852d9af09b10b80f86ed05edb5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 11 Apr 2017 20:33:10 +0800 Subject: [PATCH 259/512] [SPARK-20175][SQL] Exists should not be evaluated in Join operator ## What changes were proposed in this pull request? Similar to `ListQuery`, `Exists` should not be evaluated in `Join` operator too. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17491 from viirya/dont-push-exists-to-join. --- .../spark/sql/catalyst/expressions/predicates.scala | 3 ++- .../scala/org/apache/spark/sql/SubquerySuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 8acb740f8db8..5034566132f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -92,11 +92,12 @@ trait PredicateHelper { protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { // Non-deterministic expressions are not allowed as join conditions. case e if !e.deterministic => false - case l: ListQuery => + case _: ListQuery | _: Exists => // A ListQuery defines the query which we want to search in an IN subquery expression. // Currently the only way to evaluate an IN subquery is to convert it to a // LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule. // It cannot be evaluated as part of a Join operator. + // An Exists shouldn't be push into a Join operator too. false case e: SubqueryExpression => // non-correlated subquery will be replaced as literal diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 5fe6667ceca1..0f0199cbe277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -844,4 +844,14 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(0) :: Row(1) :: Nil) } } + + test("ListQuery and Exists should work even no correlated references") { + checkAnswer( + sql("select * from l, r where l.a = r.c AND (r.d in (select d from r) OR l.a >= 1)"), + Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: + Row(2, 1.0, 2, 3.0) :: Row(3.0, 3.0, 3, 2.0) :: Row(6, null, 6, null) :: Nil) + checkAnswer( + sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"), + Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil) + } } From 123b4fbbc331f116b45f11b9f7ecbe0b0575323d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Apr 2017 11:12:31 -0700 Subject: [PATCH 260/512] [SPARK-20289][SQL] Use StaticInvoke to box primitive types ## What changes were proposed in this pull request? 
Dataset typed API currently uses NewInstance to box primitive types (i.e. calling the constructor). Instead, it'd be slightly more idiomatic in Java to use PrimitiveType.valueOf, which can be invoked using StaticInvoke expression. ## How was this patch tested? The change should be covered by existing tests for Dataset encoders. Author: Reynold Xin Closes #17604 from rxin/SPARK-20289. --- .../sql/catalyst/JavaTypeInference.scala | 27 +++++++++---------- .../spark/sql/catalyst/ScalaReflection.scala | 14 +++++----- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 9d4617dda555..86a73a319ec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -204,20 +204,19 @@ object JavaTypeInference { typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath - case c if c == classOf[java.lang.Short] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Integer] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Long] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Double] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Byte] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Float] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Boolean] => - NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Short] || + c == classOf[java.lang.Integer] || + c == classOf[java.lang.Long] || + c == classOf[java.lang.Double] || + c == classOf[java.lang.Float] || + c == classOf[java.lang.Byte] || + c == classOf[java.lang.Boolean] => + StaticInvoke( + c, + ObjectType(c), + "valueOf", + getPath :: Nil, + propagateNull = true) case c if c == classOf[java.sql.Date] => StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0c5a818f54f5..82710a2a183a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -204,37 +204,37 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = 
ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( From 6297697f975960a3006c4e58b4964d9ac40eeaf5 Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Tue, 11 Apr 2017 12:18:31 -0700 Subject: [PATCH 261/512] [SPARK-19505][PYTHON] AttributeError on Exception.message in Python3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Added `util._message_exception` helper to use `str(e)` when `e.message` is unavailable (Python3). Grepped for all occurrences of `.message` in `pyspark/` and these were the only occurrences. ## How was this patch tested? - Doctests for helper function ## Legal This is my original work and I license the work to the project under the project’s open source license. Author: David Gingrich Closes #16845 from dgingrich/topic-spark-19505-py3-exceptions. 
--- dev/sparktestsupport/modules.py | 1 + python/pyspark/broadcast.py | 4 ++- python/pyspark/cloudpickle.py | 9 ++++--- python/pyspark/util.py | 45 +++++++++++++++++++++++++++++++++ 4 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 python/pyspark/util.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 246f5188a518..78b5b8b0f4b5 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -340,6 +340,7 @@ def __hash__(self): "pyspark.profiler", "pyspark.shuffle", "pyspark.tests", + "pyspark.util", ] ) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 74dee1420754..b1b59f73d671 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -21,6 +21,7 @@ from tempfile import NamedTemporaryFile from pyspark.cloudpickle import print_exec +from pyspark.util import _exception_message if sys.version < '3': import cPickle as pickle @@ -82,7 +83,8 @@ def dump(self, value, f): except pickle.PickleError: raise except Exception as e: - msg = "Could not serialize broadcast: " + e.__class__.__name__ + ": " + e.message + msg = "Could not serialize broadcast: %s: %s" \ + % (e.__class__.__name__, _exception_message(e)) print_exec(sys.stderr) raise pickle.PicklingError(msg) f.close() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 959fb8b357f9..389bee7eee6e 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -56,6 +56,7 @@ import traceback import weakref +from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -152,13 +153,13 @@ def dump(self, obj): except pickle.PickleError: raise except Exception as e: - if "'i' format requires" in e.message: - msg = "Object too large to serialize: " + e.message + emsg = _exception_message(e) + if "'i' format requires" in emsg: + msg = "Object too large to serialize: %s" % emsg else: - msg = "Could not serialize object: " + e.__class__.__name__ + ": " + e.message + msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) print_exec(sys.stderr) raise pickle.PicklingError(msg) - def save_memoryview(self, obj): """Fallback to save_string""" diff --git a/python/pyspark/util.py b/python/pyspark/util.py new file mode 100644 index 000000000000..e5d332ce5442 --- /dev/null +++ b/python/pyspark/util.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = [] + + +def _exception_message(excp): + """Return the message from an exception as either a str or unicode object. Supports both + Python 2 and Python 3. 
+ + >>> msg = "Exception message" + >>> excp = Exception(msg) + >>> msg == _exception_message(excp) + True + + >>> msg = u"unicöde" + >>> excp = Exception(msg) + >>> msg == _exception_message(excp) + True + """ + if hasattr(excp, "message"): + return excp.message + return str(excp) + + +if __name__ == "__main__": + import doctest + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) From cde9e328484e4007aa6b505312d7cea5461a6eaf Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 11 Apr 2017 19:30:34 -0700 Subject: [PATCH 262/512] [MINOR][DOCS] Update supported versions for Hive Metastore ## What changes were proposed in this pull request? Since SPARK-18112 and SPARK-13446, Apache Spark has started to support reading Hive metastore 2.0 ~ 2.1.1. This updates the docs. ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #17612 from dongjoon-hyun/metastore. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7ae9847983d4..c425faca4c27 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1700,7 +1700,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 1.2.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From 8ad63ee158815de5ffff7bf03cdf25aef312095f Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 12 Apr 2017 11:19:20 +0800 Subject: [PATCH 263/512] [SPARK-20291][SQL] NaNvl(FloatType, NullType) should not be cast to NaNvl(DoubleType, DoubleType) ## What changes were proposed in this pull request? `NaNvl(float value, null)` will be converted into `NaNvl(float value, Cast(null, DoubleType))` and finally `NaNvl(Cast(float value, DoubleType), Cast(null, DoubleType))`. This causes a mismatch in the output type when the input type is float. Adding an extra rule in TypeCoercion resolves this issue. ## How was this patch tested? Unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: DB Tsai Closes #17606 from dbtsai/fixNaNvl. 
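As a minimal illustration of the new coercion behavior (a sketch that mirrors the `TypeCoercionSuite` case added in this patch, not code taken from the patch itself; the literal `1.0f` is an arbitrary example value), the null literal is now cast to the type of the left operand rather than unconditionally to `DoubleType`:

```scala
// Sketch only: the rewrite expected from TypeCoercion.FunctionArgumentConversion
// when NaNvl's second argument is a null literal.
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, NaNvl}
import org.apache.spark.sql.types.{FloatType, NullType}

val before = NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType))
// After the rule fires, the null is cast to FloatType, so the expression keeps a FloatType output:
val after = NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))
```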
--- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 1 + .../sql/catalyst/analysis/TypeCoercionSuite.scala | 14 ++++++++++---- .../apache/spark/sql/DataFrameNaFunctions.scala | 3 +-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 768897dc0713..e1dd010d37a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -571,6 +571,7 @@ object TypeCoercion { NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => NaNvl(Cast(l, DoubleType), r) + case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3e0c357b6de4..011d09ff6064 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -656,14 +656,20 @@ class TypeCoercionSuite extends PlanTest { test("nanvl casts") { ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), - NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType))) + NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), + NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)), - NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType))) + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 93d565d9fe90..052d85ad33bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -408,8 +408,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val quotedColName = "`" + col.name + "`" val colValue = col.dataType match { case DoubleType | FloatType => - // nanvl only supports these types - nanvl(df.col(quotedColName), lit(null).cast(col.dataType)) + nanvl(df.col(quotedColName), lit(null)) // nanvl 
only supports these types case _ => df.col(quotedColName) } coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) From b14bfc3f8e97479ac5927c071b00ed18f2104c95 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 12 Apr 2017 12:18:01 +0800 Subject: [PATCH 264/512] [SPARK-19993][SQL] Caching logical plans containing subquery expressions does not work. ## What changes were proposed in this pull request? The sameResult() method does not work when the logical plan contains subquery expressions. **Before the fix** ```SQL scala> val ds = spark.sql("select * from s1 where s1.c1 in (select s2.c1 from s2 where s1.c1 = s2.c1)") ds: org.apache.spark.sql.DataFrame = [c1: int] scala> ds.cache res13: ds.type = [c1: int] scala> spark.sql("select * from s1 where s1.c1 in (select s2.c1 from s2 where s1.c1 = s2.c1)").explain(true) == Analyzed Logical Plan == c1: int Project [c1#86] +- Filter c1#86 IN (list#78 [c1#86]) : +- Project [c1#87] : +- Filter (outer(c1#86) = c1#87) : +- SubqueryAlias s2 : +- Relation[c1#87] parquet +- SubqueryAlias s1 +- Relation[c1#86] parquet == Optimized Logical Plan == Join LeftSemi, ((c1#86 = c1#87) && (c1#86 = c1#87)) :- Relation[c1#86] parquet +- Relation[c1#87] parquet ``` **Plan after fix** ```SQL == Analyzed Logical Plan == c1: int Project [c1#22] +- Filter c1#22 IN (list#14 [c1#22]) : +- Project [c1#23] : +- Filter (outer(c1#22) = c1#23) : +- SubqueryAlias s2 : +- Relation[c1#23] parquet +- SubqueryAlias s1 +- Relation[c1#22] parquet == Optimized Logical Plan == InMemoryRelation [c1#22], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- *BroadcastHashJoin [c1#1, c1#1], [c1#2, c1#2], LeftSemi, BuildRight :- *FileScan parquet default.s1[c1#1] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/dbiswal/mygit/apache/spark/bin/spark-warehouse/s1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295)))) +- *FileScan parquet default.s2[c1#2] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/dbiswal/mygit/apache/spark/bin/spark-warehouse/s2], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` ## How was this patch tested? New tests are added to CachedTableSuite. Author: Dilip Biswal Closes #17330 from dilipbiswal/subquery_cache_final. 
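To make the user-visible effect concrete, the following is a hedged sketch (reusing the `s1`/`s2` tables from the plans above; it is not part of the patch) of how a cached query containing a correlated subquery is expected to be reused once subquery expressions are canonicalized:

```scala
// Sketch only: after this change, re-running the same query text should hit the cache.
val q = "select * from s1 where s1.c1 in (select s2.c1 from s2 where s1.c1 = s2.c1)"
spark.sql(q).cache()        // registers the canonicalized plan, subquery included
spark.sql(q).explain(true)  // the optimized plan is now expected to show an InMemoryRelation
```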
--- .../sql/catalyst/expressions/subquery.scala | 26 +++- .../spark/sql/catalyst/plans/QueryPlan.scala | 43 +++--- .../sql/execution/DataSourceScanExec.scala | 7 +- .../apache/spark/sql/CachedTableSuite.scala | 143 +++++++++++++++++- .../hive/execution/HiveTableScanExec.scala | 5 +- 5 files changed, 198 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 59db28d58afc..d7b493d521dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -47,7 +47,6 @@ abstract class SubqueryExpression( plan: LogicalPlan, children: Seq[Expression], exprId: ExprId) extends PlanExpression[LogicalPlan] { - override lazy val resolved: Boolean = childrenResolved && plan.resolved override lazy val references: AttributeSet = if (plan.resolved) super.references -- plan.outputSet else super.references @@ -59,6 +58,13 @@ abstract class SubqueryExpression( children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) case _ => false } + def canonicalize(attrs: AttributeSeq): SubqueryExpression = { + // Normalize the outer references in the subquery plan. + val normalizedPlan = plan.transformAllExpressions { + case OuterReference(r) => OuterReference(QueryPlan.normalizeExprId(r, attrs)) + } + withNewPlan(normalizedPlan).canonicalized.asInstanceOf[SubqueryExpression] + } } object SubqueryExpression { @@ -236,6 +242,12 @@ case class ScalarSubquery( override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ScalarSubquery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } object ScalarSubquery { @@ -268,6 +280,12 @@ case class ListQuery( override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) override def toString: String = s"list#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ListQuery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } /** @@ -290,4 +308,10 @@ case class Exists( override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) override def toString: String = s"exists#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + Exists( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3008e8cb8465..2fb65bd43550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -377,7 +377,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT // As the root of the expression, Alias will always take an arbitrary exprId, we need to // normalize that for equality testing, by assigning expr id from 0 incrementally. The // alias name doesn't matter and should be erased. 
- Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes) + Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => // Top level `AttributeReference` may also be used for output like `Alias`, we should @@ -385,7 +386,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT id += 1 ar.withExprId(ExprId(id)) - case other => normalizeExprId(other) + case other => QueryPlan.normalizeExprId(other, allAttributes) }.withNewChildren(canonicalizedChildren) } @@ -395,23 +396,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def preCanonicalized: PlanType = this - /** - * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` - * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we - * do not use `BindReferences` here as the plan may take the expression as a parameter with type - * `Attribute`, and replace it with `BoundReference` will cause error. - */ - protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = { - e.transformUp { - case ar: AttributeReference => - val ordinal = input.indexOf(ar.exprId) - if (ordinal == -1) { - ar - } else { - ar.withExprId(ExprId(ordinal)) - } - }.canonicalized.asInstanceOf[T] - } /** * Returns true when the given query plan will return the same results as this query plan. @@ -438,3 +422,24 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ lazy val allAttributes: AttributeSeq = children.flatMap(_.output) } + +object QueryPlan { + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we + * do not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. 
+ */ + def normalizeExprId[T <: Expression](e: T, input: AttributeSeq): T = { + e.transformUp { + case s: SubqueryExpression => s.canonicalize(input) + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized.asInstanceOf[T] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 3a9132d74ac1..866fa9853321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} @@ -516,10 +517,10 @@ case class FileSourceScanExec( override lazy val canonicalized: FileSourceScanExec = { FileSourceScanExec( relation, - output.map(normalizeExprId(_, output)), + output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, - partitionFilters.map(normalizeExprId(_, output)), - dataFilters.map(normalizeExprId(_, output)), + partitionFilters.map(QueryPlan.normalizeExprId(_, output)), + dataFilters.map(QueryPlan.normalizeExprId(_, output)), None) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 7a7d52b21427..e66fe97afad4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.CleanerListener import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.execution.RDDScanExec +import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ @@ -76,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext sum } + private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { + plan.collect { + case InMemoryTableScanExec(_, _, relation) => + getNumInMemoryTablesRecursively(relation.child) + 1 + }.sum + } + test("withColumn doesn't invalidate cached dataframe") { var evalCount = 0 val myUDF = udf((x: String) => { evalCount += 1; "result" }) @@ -670,4 +677,138 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(spark.read.parquet(path).filter($"id" > 4).count() == 15) } } + + test("SPARK-19993 simple subquery caching") { + withTempView("t1", "t2") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * 
FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + + // Additional predicate in the subquery plan should cause a cache miss + val cachedMissDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2 where c1 = 0) + """.stripMargin) + assert(getNumInMemoryRelations(cachedMissDs) == 0) + } + } + + test("SPARK-19993 subquery caching with correlated predicates") { + withTempView("t1", "t2") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(1).toDF("c1").createOrReplaceTempView("t2") + + // Simple correlated predicate in subquery + sql( + """ + |SELECT * FROM t1 + |WHERE + |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + } + } + + test("SPARK-19993 subquery with cached underlying relation") { + withTempView("t1") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + spark.catalog.cacheTable("t1") + + // underlying table t1 is cached as well as the query that refers to it. + val ds = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 2) + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin).cache() + assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3) + } + } + + test("SPARK-19993 nested subquery caching and scalar + predicate subqueris") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + Seq(1).toDF("c1").createOrReplaceTempView("t3") + Seq(1).toDF("c1").createOrReplaceTempView("t4") + + // Nested predicate subquery + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + + // Scalar subquery and predicate subquery + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin).cache() + + val cachedDs2 = + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs2) == 1) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index fab0d7fa8482..666548d1a490 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution._ import 
org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ @@ -203,9 +204,9 @@ case class HiveTableScanExec( override lazy val canonicalized: HiveTableScanExec = { val input: AttributeSeq = relation.output HiveTableScanExec( - requestedAttributes.map(normalizeExprId(_, input)), + requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), relation.canonicalized.asInstanceOf[CatalogRelation], - partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession) + partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession) } override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) From b9384382484a9f5c6b389742e7fdf63865de81c0 Mon Sep 17 00:00:00 2001 From: Lee Dongjin Date: Wed, 12 Apr 2017 09:12:14 +0100 Subject: [PATCH 265/512] [MINOR][DOCS] Fix spacings in Structured Streaming Programming Guide ## What changes were proposed in this pull request? 1. Omitted space between the sentences: `... on static data.The Spark SQL engine will ...` -> `... on static data. The Spark SQL engine will ...` 2. Omitted colon in Output Model section. ## How was this patch tested? None. Author: Lee Dongjin Closes #17564 from dongjinleekr/feature/fix-programming-guide. --- docs/structured-streaming-programming-guide.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 37a1d6189a42..3cf7151819e2 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,7 +8,7 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data.The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* **Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. 
In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. @@ -362,7 +362,7 @@ A query on the input will generate the "Result Table". Every trigger interval (s ![Model](img/structured-streaming-model.png) -The "Output" is defined as what gets written out to the external storage. The output can be defined in different modes +The "Output" is defined as what gets written out to the external storage. The output can be defined in a different mode: - *Complete Mode* - The entire updated Result Table will be written to the external storage. It is up to the storage connector to decide how to handle writing of the entire table. From bca4259f12b32eeb156b6755d0ec5e16d8e566b3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 12 Apr 2017 09:16:39 +0100 Subject: [PATCH 266/512] [MINOR][DOCS] JSON APIs related documentation fixes ## What changes were proposed in this pull request? This PR proposes corrections related to JSON APIs as below: - Rendering links in Python documentation - Replacing `RDD` to `Dataset` in programing guide - Adding missing description about JSON Lines consistently in `DataFrameReader.json` in Python API - De-duplicating little bit of `DataFrameReader.json` in Scala/Java API ## How was this patch tested? Manually build the documentation via `jekyll build`. Corresponding snapstops will be left on the codes. Note that currently there are Javadoc8 breaks in several places. These are proposed to be handled in https://github.com/apache/spark/pull/17477. So, this PR does not fix those. Author: hyukjinkwon Closes #17602 from HyukjinKwon/minor-json-documentation. --- docs/sql-programming-guide.md | 4 ++-- .../spark/examples/sql/JavaSQLDataSourceExample.java | 2 +- .../apache/spark/examples/sql/SQLDataSourceExample.scala | 2 +- python/pyspark/sql/readwriter.py | 8 +++++--- python/pyspark/sql/streaming.py | 4 ++-- .../main/scala/org/apache/spark/sql/DataFrameReader.scala | 4 ++-- 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c425faca4c27..28942b68fa20 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -883,7 +883,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset[Row]`. -This conversion can be done using `SparkSession.read.json()` on either an RDD of String, +This conversion can be done using `SparkSession.read.json()` on either a `Dataset[String]`, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each @@ -897,7 +897,7 @@ For a regular multi-line JSON file, set the `wholeFile` option to `true`.
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset`. -This conversion can be done using `SparkSession.read().json()` on either an RDD of String, +This conversion can be done using `SparkSession.read().json()` on either a `Dataset`, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 1a7054614b34..b66abaed6600 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -215,7 +215,7 @@ private static void runJsonDatasetExample(SparkSession spark) { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an Dataset[String] storing one JSON object per string. + // a Dataset storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); Dataset anotherPeopleDataset = spark.createDataset(jsonData, Encoders.STRING()); diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 82fd56de3984..ad74da72bd5e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -139,7 +139,7 @@ object SQLDataSourceExample { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an Dataset[String] storing one JSON object per string + // a Dataset[String] storing one JSON object per string val otherPeopleDataset = spark.createDataset( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) val otherPeople = spark.read.json(otherPeopleDataset) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index d912f395dafc..960fb882cf90 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -173,8 +173,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads JSON files and returns the results as a :class:`DataFrame`. - `JSON Lines `_(newline-delimited JSON) is supported by default. - For JSON (one record per file), set the `wholeFile` parameter to ``true``. + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -634,7 +634,9 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): - """Saves the content of the :class:`DataFrame` in JSON format at the specified path. + """Saves the content of the :class:`DataFrame` in JSON format + (`JSON Lines text format or newline-delimited JSON `_) at the + specified path. :param path: the path in any Hadoop supported file system :param mode: specifies the behavior of the save operation when data already exists. 
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 3b604963415f..65b59d480da3 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -405,8 +405,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. - `JSON Lines `_(newline-delimited JSON) is supported by default. - For JSON (one record per file), set the `wholeFile` parameter to ``true``. + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 49691c15d0f7..c1b32917415a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -268,8 +268,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (JSON Lines text format or - * newline-delimited JSON) and returns the result as a `DataFrame`. + * Loads a JSON file and returns the results as a `DataFrame`. + * * See the documentation on the overloaded `json()` method with varargs for more details. * * @since 1.4.0 From 044f7ecbfd75ac5a13bfc8cd01990e195c9bd178 Mon Sep 17 00:00:00 2001 From: Brendan Dwyer Date: Wed, 12 Apr 2017 09:24:41 +0100 Subject: [PATCH 267/512] [SPARK-20298][SPARKR][MINOR] fixed spelling mistake "charactor" ## What changes were proposed in this pull request? Fixed spelling of "charactor" ## How was this patch tested? Spelling change only Author: Brendan Dwyer Closes #17611 from bdwyer2/SPARK-20298. --- R/pkg/R/DataFrame.R | 10 +++++----- R/pkg/R/SQLContext.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ec85f723c08c..88a138fd8eb1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2818,14 +2818,14 @@ setMethod("write.df", signature(df = "SparkDataFrame"), function(df, path = NULL, source = NULL, mode = "error", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", "in 'spark.sql.sources.default' configuration by default.") } if (!is.character(mode)) { - stop("mode should be charactor or omitted. It is 'error' by default.") + stop("mode should be character or omitted. 
It is 'error' by default.") } if (is.null(source)) { source <- getDefaultSqlSource() @@ -3040,7 +3040,7 @@ setMethod("fillna", signature(x = "SparkDataFrame"), function(x, value, cols = NULL) { if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { - stop("value should be an integer, numeric, charactor or named list.") + stop("value should be an integer, numeric, character or named list.") } if (class(value) == "list") { @@ -3052,7 +3052,7 @@ setMethod("fillna", # Check each item in the named list is of valid type lapply(value, function(v) { if (!(class(v) %in% c("integer", "numeric", "character"))) { - stop("Each item in value should be an integer, numeric or charactor.") + stop("Each item in value should be an integer, numeric or character.") } }) @@ -3598,7 +3598,7 @@ setMethod("write.stream", "in 'spark.sql.sources.default' configuration by default.") } if (!is.null(outputMode) && !is.character(outputMode)) { - stop("outputMode should be charactor or omitted.") + stop("outputMode should be character or omitted.") } if (is.null(source)) { source <- getDefaultSqlSource() diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index c2a1e240ad39..f5c3a749fe0a 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -606,7 +606,7 @@ tableToDF <- function(tableName) { #' @note read.df since 1.4.0 read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 58cf24256a94..3fbb618ddfc3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2926,9 +2926,9 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) expect_error(write.df(df, path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(write.df(df, mode = TRUE), - "mode should be charactor or omitted. It is 'error' by default.") + "mode should be character or omitted. It is 'error' by default.") }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { @@ -2947,7 +2947,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # Arguments checking in R side. expect_error(read.df(path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) From ffc57b0118b58de57520967d8e8730b11baad507 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 12 Apr 2017 01:30:00 -0700 Subject: [PATCH 268/512] [SPARK-20302][SQL] Short circuit cast when from and to types are structurally the same ## What changes were proposed in this pull request? 
When we perform a cast expression and the from and to types are structurally the same (having the same structure but different field names), we should be able to skip the actual cast. ## How was this patch tested? Added unit tests for the newly introduced functions. Author: Reynold Xin Closes #17614 from rxin/SPARK-20302. --- .../spark/sql/catalyst/expressions/Cast.scala | 65 ++++++++++++------- .../org/apache/spark/sql/types/DataType.scala | 26 ++++++++ .../sql/catalyst/expressions/CastSuite.scala | 14 ++++ .../spark/sql/types/DataTypeSuite.scala | 31 +++++++++ 4 files changed, 113 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1049915986d9..bb1273f5c3d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String }) } - private[this] def cast(from: DataType, to: DataType): Any => Any = to match { - case dt if dt == from => identity[Any] - case StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case CalendarIntervalType => castToInterval(from) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) - case udt: UserDefinedType[_] - if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - identity[Any] - case _: UserDefinedType[_] => - throw new SparkException(s"Cannot cast $from to $to.") + private[this] def cast(from: DataType, to: DataType): Any => Any = { + // If the cast does not change the structure, then we don't really need to cast anything. + // We can return what the children return. Same thing should happen in the codegen path. 
+ if (DataType.equalsStructurally(from, to)) { + identity + } else { + to match { + case dt if dt == from => identity[Any] + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case CalendarIntervalType => castToInterval(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => + castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + identity[Any] + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") + } + } } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) protected override def nullSafeEval(input: Any): Any = cast(input) + override def genCode(ctx: CodegenContext): ExprCode = { + // If the cast does not change the structure, then we don't really need to cast anything. + // We can return what the children return. Same thing should happen in the interpreted path. + if (DataType.equalsStructurally(child.dataType, dataType)) { + child.genCode(ctx) + } else { + super.genCode(ctx) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 520aff5e2b67..30745c6a9d42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -288,4 +288,30 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + /** + * Returns true if the two data types share the same "shape", i.e. the types (including + * nullability) are the same, but the field names don't need to be the same. 
+ */ + def equalsStructurally(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (left: ArrayType, right: ArrayType) => + equalsStructurally(left.elementType, right.elementType) && + left.containsNull == right.containsNull + + case (left: MapType, right: MapType) => + equalsStructurally(left.keyType, right.keyType) && + equalsStructurally(left.valueType, right.valueType) && + left.valueContainsNull == right.valueContainsNull + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields) + .forall { case (l, r) => + equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 8eccadbdd8af..a7ffa884d228 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure) assert(cast(1.0, DateType).checkInputDataTypes().isFailure) } + + test("SPARK-20302 cast with same structure") { + val from = new StructType() + .add("a", IntegerType) + .add("b", new StructType().add("b1", LongType)) + + val to = new StructType() + .add("a1", IntegerType) + .add("b1", new StructType().add("b11", LongType)) + + val input = Row(10, Row(12L)) + + checkEvaluation(cast(Literal.create(input, from), to), input) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index f078ef013387..c4635c8f126a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite { checkCatalogString(ArrayType(createStruct(40))) checkCatalogString(MapType(IntegerType, StringType)) checkCatalogString(MapType(IntegerType, createStruct(40))) + + def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = { + val testName = s"equalsStructurally: (from: $from, to: $to)" + test(testName) { + assert(DataType.equalsStructurally(from, to) === expected) + } + } + + checkEqualsStructurally(BooleanType, BooleanType, true) + checkEqualsStructurally(IntegerType, IntegerType, true) + checkEqualsStructurally(IntegerType, LongType, false) + checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true) + checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false) + + checkEqualsStructurally( + new StructType().add("f1", IntegerType), + new StructType().add("f2", IntegerType), + true) + checkEqualsStructurally( + new StructType().add("f1", IntegerType), + new StructType().add("f2", IntegerType, false), + false) + + checkEqualsStructurally( + new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)), + new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), + true) + checkEqualsStructurally( + new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)), + new 
StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), + false) } From 2e1fd46e12bf948490ece2caa73d227b6a924a14 Mon Sep 17 00:00:00 2001 From: jtoka Date: Wed, 12 Apr 2017 11:36:08 +0100 Subject: [PATCH 269/512] [SPARK-20296][TRIVIAL][DOCS] Count distinct error message for streaming ## What changes were proposed in this pull request? Update count distinct error message for streaming datasets/dataframes to match current behavior. These aggregations are not yet supported, regardless of whether the dataset/dataframe is aggregated. Author: jtoka Closes #17609 from jtoka/master. --- .../sql/catalyst/analysis/UnsupportedOperationChecker.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 7da7f55aa5d7..3f76f26dbe4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -139,9 +139,8 @@ object UnsupportedOperationChecker { } throwErrorIf( child.isStreaming && distinctAggExprs.nonEmpty, - "Distinct aggregations are not supported on streaming DataFrames/Datasets, unless " + - "it is on aggregated DataFrame/Dataset in Complete output mode. Consider using " + - "approximate distinct aggregation (e.g. approx_count_distinct() instead of count()).") + "Distinct aggregations are not supported on streaming DataFrames/Datasets. Consider " + + "using approx_count_distinct() instead.") case _: Command => throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + From ceaf77ae43a14e993ac6d1ff34b50575323d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 12 Apr 2017 12:38:48 +0100 Subject: [PATCH 270/512] [SPARK-18692][BUILD][DOCS] Test Java 8 unidoc build on Jenkins ## What changes were proposed in this pull request? This PR proposes to run Spark unidoc to test the Javadoc 8 build, as Javadoc 8 is easily re-broken. There are several problems with it: - It introduces a little extra time to run the tests. In my case, it took 1.5 mins more (`Elapsed :[94.8746569157]`). How it was tested is described in "How was this patch tested?". 
This was tested via manually adding `time.time()` as below: ```diff profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark unidoc (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) + import time + st = time.time() exec_sbt(profiles_and_goals) + print("Elapsed :[%s]" % str(time.time() - st)) ``` produces ``` ... ======================================================================== Building Unidoc API Documentation ======================================================================== ... [info] Main Java API documentation successful. ... Elapsed :[94.8746569157] ... Author: hyukjinkwon Closes #17477 from HyukjinKwon/SPARK-18692. --- .../org/apache/spark/rpc/RpcEndpoint.scala | 10 +++++----- .../org/apache/spark/rpc/RpcTimeout.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 4 ++-- .../scheduler/ExternalClusterManager.scala | 2 +- .../spark/scheduler/TaskSchedulerImpl.scala | 8 ++++---- .../apache/spark/storage/BlockManager.scala | 2 +- .../org/apache/spark/AccumulatorSuite.scala | 4 ++-- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../org/apache/spark/LocalSparkContext.scala | 2 +- .../scheduler/SchedulerIntegrationSuite.scala | 4 ++-- .../serializer/SerializerPropertiesSuite.scala | 2 +- dev/run-tests.py | 15 +++++++++++++++ .../spark/graphx/LocalSparkContext.scala | 2 +- .../spark/ml/classification/Classifier.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 8 ++++++-- .../org/apache/spark/ml/feature/LSHTest.scala | 12 ++++++++---- .../apache/spark/ml/param/ParamsSuite.scala | 2 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 6 ++++-- .../spark/ml/util/DefaultReadWriteTest.scala | 18 +++++++++--------- .../apache/spark/ml/util/StopwatchSuite.scala | 4 ++-- .../apache/spark/ml/util/TempDirectory.scala | 4 +++- .../spark/mllib/tree/ImpuritySuite.scala | 2 +- .../mllib/util/MLlibTestSparkContext.scala | 2 +- .../cluster/mesos/MesosSchedulerUtils.scala | 6 +++--- .../apache/spark/sql/RandomDataGenerator.scala | 8 ++++---- .../spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../org/apache/spark/sql/catalog/Catalog.scala | 2 +- .../DatasetSerializerRegistratorSuite.scala | 4 +++- .../sql/streaming/FileStreamSourceSuite.scala | 4 ++-- .../spark/sql/streaming/StreamSuite.scala | 4 ++-- .../sql/streaming/StreamingQuerySuite.scala | 12 ++++++++---- .../apache/spark/sql/test/SQLTestUtils.scala | 18 ++++++++++-------- .../apache/spark/sql/test/TestSQLContext.scala | 2 +- .../java/org/apache/hive/service/Service.java | 2 +- .../apache/hive/service/ServiceOperations.java | 12 ++++++------ .../hive/service/auth/HttpAuthUtils.java | 2 +- .../auth/PasswdAuthenticationProvider.java | 2 +- .../service/auth/TSetIpAddressProcessor.java | 9 +++------ .../hive/service/cli/CLIServiceUtils.java | 2 +- .../cli/operation/ClassicTableTypeMapping.java | 6 +++--- .../cli/operation/TableTypeMapping.java | 2 +- .../service/cli/session/SessionManager.java | 4 +++- .../ThreadFactoryWithGarbageCleanup.java | 6 +++--- .../apache/spark/sql/hive/HiveInspectors.scala | 4 ++-- .../sql/hive/execution/HiveQueryFileTest.scala | 2 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 4 ++-- .../spark/streaming/rdd/MapWithStateRDD.scala | 4 ++-- .../scheduler/rate/PIDRateEstimator.scala | 2 +- .../scheduler/rate/RateEstimator.scala | 2 +- 49 files changed, 140 insertions(+), 106 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index 
0ba95169529e..97eed540b8f5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -35,7 +35,7 @@ private[spark] trait RpcEnvFactory { * * The life-cycle of an endpoint is: * - * constructor -> onStart -> receive* -> onStop + * {@code constructor -> onStart -> receive* -> onStop} * * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use * [[ThreadSafeRpcEndpoint]] @@ -63,16 +63,16 @@ private[spark] trait RpcEndpoint { } /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a - * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.send` or `RpcCallContext.reply`. If receiving a + * unmatched message, `SparkException` will be thrown and sent to `onError`. */ def receive: PartialFunction[Any, Unit] = { case _ => throw new SparkException(self + " does not implement 'receive'") } /** - * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, - * [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.ask`. If receiving a unmatched message, + * `SparkException` will be thrown and sent to `onError`. */ def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case _ => context.sendFailure(new SparkException(self + " won't reply anything")) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 2c9a976e7693..0557b7a3cc0b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkConf import org.apache.spark.util.{ThreadUtils, Utils} /** - * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + * An exception thrown if RpcTimeout modifies a `TimeoutException`. */ private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) extends TimeoutException(message) { initCause(cause) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 09717316833a..aab177f257a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -607,7 +607,7 @@ class DAGScheduler( * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name * - * @throws Exception when the job fails + * @note Throws `Exception` when the job fails */ def runJob[T, U]( rdd: RDD[T], @@ -644,7 +644,7 @@ class DAGScheduler( * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD - * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param evaluator `ApproximateEvaluator` to receive the partial results * @param callSite where in the user program this job was called * @param timeout maximum time to wait for the job, in milliseconds * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala index d1ac7131baba..47f3527a32c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala @@ -42,7 +42,7 @@ private[spark] trait ExternalClusterManager { /** * Create a scheduler backend for the given SparkContext and scheduler. This is - * called after task scheduler is created using [[ExternalClusterManager.createTaskScheduler()]]. + * called after task scheduler is created using `ExternalClusterManager.createTaskScheduler()`. * @param sc SparkContext * @param masterURL the master URL * @param scheduler TaskScheduler that will be used with the scheduler backend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index c849a16023a7..1b6bc9139f9c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -38,7 +38,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. - * It can also work with a local setup by using a [[LocalSchedulerBackend]] and setting + * It can also work with a local setup by using a `LocalSchedulerBackend` and setting * isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking * up to launch speculative tasks, etc. * @@ -704,12 +704,12 @@ private[spark] object TaskSchedulerImpl { * Used to balance containers across hosts. * * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource + * resource offers representing the order in which the offers should be used. The resource * offers are ordered such that we'll allocate one container on each host before allocating a * second container on any host, and so on, in order to reduce the damage if a host fails. * - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] + * For example, given {@literal }, {@literal } and + * {@literal }, returns {@literal [o1, o5, o4, o2, o6, o3]}. */ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { val _keyList = new ArrayBuffer[K](map.size) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 63acba65d3c5..3219969bcd06 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -66,7 +66,7 @@ private[spark] trait BlockData { /** * Returns a Netty-friendly wrapper for the block's data. * - * @see [[ManagedBuffer#convertToNetty()]] + * Please see `ManagedBuffer.convertToNetty()` for more details. 
*/ def toNetty(): Object diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 6d03ee091e4e..ddbcb2d19dcb 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -243,7 +243,7 @@ private[spark] object AccumulatorSuite { import InternalAccumulator._ /** - * Create a long accumulator and register it to [[AccumulatorContext]]. + * Create a long accumulator and register it to `AccumulatorContext`. */ def createLongAccum( name: String, @@ -258,7 +258,7 @@ private[spark] object AccumulatorSuite { } /** - * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the + * Make an `AccumulableInfo` out of an [[Accumulable]] with the intent to use the * info as an accumulator update. */ def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index eb3fb99747d1..fe944031bc94 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalSh /** * This suite creates an external shuffle server and routes all shuffle fetches through it. * Note that failures in this suite may arise due to changes in Spark that invalidate expectations - * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how + * set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how * we hash files into folders. */ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 24ec99c7e5e6..1dd89bcbe36b 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite -/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +/** Manages a local `sc` `SparkContext` variable, correctly stopping it after each test. */ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 8103983c4392..8300607ea888 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -95,12 +95,12 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa } /** - * A map from partition -> results for all tasks of a job when you call this test framework's + * A map from partition to results for all tasks of a job when you call this test framework's * [[submit]] method. Two important considerations: * * 1. If there is a job failure, results may or may not be empty. 
If any tasks succeed before * the job has failed, they will get included in `results`. Instead, check for job failure by - * checking [[failure]]. (Also see [[assertDataStructuresEmpty()]]) + * checking [[failure]]. (Also see `assertDataStructuresEmpty()`) * * 2. This only gets cleared between tests. So you'll need to do special handling if you submit * more than one job in one test. diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index 4ce3b941bea5..99882bf76e29 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset /** * Tests to ensure that [[Serializer]] implementations obey the API contracts for methods that * describe properties of the serialized stream, such as - * [[Serializer.supportsRelocationOfSerializedObjects]]. + * `Serializer.supportsRelocationOfSerializedObjects`. */ class SerializerPropertiesSuite extends SparkFunSuite { diff --git a/dev/run-tests.py b/dev/run-tests.py index 04035b33e6a6..450b68123e1f 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -344,6 +344,19 @@ def build_spark_sbt(hadoop_version): exec_sbt(profiles_and_goals) +def build_spark_unidoc_sbt(hadoop_version): + set_title_and_block("Building Unidoc API Documentation", "BLOCK_DOCUMENTATION") + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags + sbt_goals = ["unidoc"] + profiles_and_goals = build_profiles + sbt_goals + + print("[info] Building Spark unidoc (w/Hive 1.2.1) using SBT with these arguments: ", + " ".join(profiles_and_goals)) + + exec_sbt(profiles_and_goals) + + def build_spark_assembly_sbt(hadoop_version): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags @@ -352,6 +365,8 @@ def build_spark_assembly_sbt(hadoop_version): print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + # Make sure that Java and Scala API documentation can be generated + build_spark_unidoc_sbt(hadoop_version) def build_apache_spark(build_tool, hadoop_version): diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index d2ad9be55577..66c4747fec26 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkConf import org.apache.spark.SparkContext /** - * Provides a method to run tests against a {@link SparkContext} variable that is correctly stopped + * Provides a method to run tests against a `SparkContext` variable that is correctly stopped * after each test. 
*/ trait LocalSparkContext { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d8608d885d6f..bc0b49d48d32 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -74,7 +74,7 @@ abstract class Classifier[ * and features (`Vector`). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). - * @throws SparkException if any label is not an integer >= 0 + * @note Throws `SparkException` if any label is a non-integer or is negative */ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4cdbf845ae4f..4a7e4dd80f24 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -230,7 +230,9 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +/** + * Used to test [[Pipeline]] with `MLWritable` stages + */ class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -257,7 +259,9 @@ object WritableStage extends MLReadable[WritableStage] { override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ +/** + * Used to test [[Pipeline]] with non-`MLWritable` stages + */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala index dd4dd62b8cfe..db4f56ed60d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -29,8 +29,10 @@ private[ml] object LSHTest { * the following property is satisfied. * * There exist dist1, dist2, p1, p2, so that for any two elements e1 and e2, - * If dist(e1, e2) <= dist1, then Pr{h(x) == h(y)} >= p1 - * If dist(e1, e2) >= dist2, then Pr{h(x) == h(y)} <= p2 + * If dist(e1, e2) is less than or equal to dist1, then Pr{h(x) == h(y)} is greater than + * or equal to p1 + * If dist(e1, e2) is greater than or equal to dist2, then Pr{h(x) == h(y)} is less than + * or equal to p2 * * This is called locality sensitive property. This method checks the property on an * existing dataset and calculate the probabilities. 
@@ -38,8 +40,10 @@ private[ml] object LSHTest { * * This method hashes each elements to hash buckets using LSH, and calculate the false positive * and false negative: - * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) > distFP - * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) < distFN + * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) is greater + * than distFP + * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) is less + * than distFN * * @param dataset The dataset to verify the locality sensitive hashing property. * @param lsh The lsh instance to perform the hashing diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index aa9c53ca30ee..78a33e05e0e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -377,7 +377,7 @@ class ParamsSuite extends SparkFunSuite { object ParamsSuite extends SparkFunSuite { /** - * Checks common requirements for [[Params.params]]: + * Checks common requirements for `Params.params`: * - params are ordered by names * - param parent has the same UID as the object's UID * - param name is the same as the param method name diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index c90cb8ca1034..92a236928e90 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -34,7 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite { * Convert the given data to a DataFrame, and set the features and label metadata. * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. - * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param categoricalFeatures Map: categorical feature index to number of distinct values * @param numClasses Number of classes label can take. If 0, mark as continuous. * @return DataFrame with metadata */ @@ -69,7 +69,9 @@ private[ml] object TreeTests extends SparkFunSuite { df("label").as("label", labelMetadata)) } - /** Java-friendly version of [[setMetadata()]] */ + /** + * Java-friendly version of `setMetadata()` + */ def setMetadata( data: JavaRDD[LabeledPoint], categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index bfe8f12258bb..27d606cb05dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -81,20 +81,20 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Default test for Estimator, Model pairs: * - Explicitly set Params, and train model - * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model + * - Test save/load using `testDefaultReadWrite` on Estimator and Model * - Check Params on Estimator and Model * - Compare model data * - * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s. 
+ * This requires that `Model`'s `Param`s should be a subset of `Estimator`'s `Param`s. * * @param estimator Estimator to test - * @param dataset Dataset to pass to [[Estimator.fit()]] - * @param testEstimatorParams Set of [[Param]] values to set in estimator - * @param testModelParams Set of [[Param]] values to set in model - * @param checkModelData Method which takes the original and loaded [[Model]] and compares their - * data. This method does not need to check [[Param]] values. - * @tparam E Type of [[Estimator]] - * @tparam M Type of [[Model]] produced by estimator + * @param dataset Dataset to pass to `Estimator.fit()` + * @param testEstimatorParams Set of `Param` values to set in estimator + * @param testModelParams Set of `Param` values to set in model + * @param checkModelData Method which takes the original and loaded `Model` and compares their + * data. This method does not need to check `Param` values. + * @tparam E Type of `Estimator` + * @tparam M Type of `Model` produced by estimator */ def testEstimatorAndModelReadWrite[ E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 141249a427a4..54e363a8b9f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -105,8 +105,8 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { private object StopwatchSuite extends SparkFunSuite { /** - * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and - * returns the duration reported by the stopwatch. + * Checks the input stopwatch on a task that takes a random time (less than 10ms) to finish. + * Validates and returns the duration reported by the stopwatch. */ def checkStopwatch(sw: Stopwatch): Long = { val ubStart = now diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 8f11bbc8e47a..50b73e0e99a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -30,7 +30,9 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => private var _tempDir: File = _ - /** Returns the temporary directory as a [[File]] instance. */ + /** + * Returns the temporary directory as a `File` instance. + */ protected def tempDir: File = _tempDir override def beforeAll(): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 14152cdd63bc..d0f02dd966bd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} /** - * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + * Test suites for `GiniAggregator` and `EntropyAggregator`. 
*/ class ImpuritySuite extends SparkFunSuite { test("Gini impurity does not support negative labels") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 6bb7ed9c9513..720237bd2ddd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -60,7 +60,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => * A helper object for importing SQL implicits. * * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, + * This is because we create the `SQLContext` immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 3f25535cb5ec..9d81025a3016 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -239,7 +239,7 @@ trait MesosSchedulerUtils extends Logging { } /** - * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * Converts the attributes from the resource offer into a Map of name to Attribute Value * The attribute values are the mesos attribute types and they are * * @param offerAttributes the attributes offered @@ -296,7 +296,7 @@ trait MesosSchedulerUtils extends Logging { /** * Parses the attributes constraints provided to spark and build a matching data struct: - * Map[, Set[values-to-match]] + * {@literal Map[, Set[values-to-match]} * The constraints are specified as ';' separated key-value pairs where keys and values * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for * multiple values (comma separated). For example: @@ -354,7 +354,7 @@ trait MesosSchedulerUtils extends Logging { * container overheads. * * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value - * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * @return memory requirement as (0.1 * memoryOverhead) or MEMORY_OVERHEAD_MINIMUM * (whichever is larger) */ def executorMemory(sc: SparkContext): Int = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 850869799507..8ae3ff5043e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -117,11 +117,11 @@ object RandomDataGenerator { } /** - * Returns a function which generates random values for the given [[DataType]], or `None` if no + * Returns a function which generates random values for the given `DataType`, or `None` if no * random data generator is defined for that data type. 
The generated values will use an external - * representation of the data type; for example, the random generator for [[DateType]] will return - * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. - * For a [[UserDefinedType]] for a class X, an instance of class X is returned. + * representation of the data type; for example, the random generator for `DateType` will return + * instances of [[java.sql.Date]] and the generator for `StructType` will return a [[Row]]. + * For a `UserDefinedType` for a class X, an instance of class X is returned. * * @param dataType the type to generate values for * @param nullable whether null values should be generated diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index a6d90409382e..769addf3b29e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Benchmark /** - * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields. + * Benchmark `UnsafeProjection` for fixed-length/primitive-type fields. */ object UnsafeProjectionBenchmark { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 074952ff7900..7e5da012f84c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -510,7 +510,7 @@ abstract class Catalog { def refreshTable(tableName: String): Unit /** - * Invalidates and refreshes all the cached data (and the associated metadata) for any [[Dataset]] + * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` * that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate * everything that is cached. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala index 0f3d0cefe3bb..92c5656f65bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -56,7 +56,9 @@ object TestRegistrator { def apply(): TestRegistrator = new TestRegistrator() } -/** A [[Serializer]] that takes a [[KryoData]] and serializes it as KryoData(0). */ +/** + * A `Serializer` that takes a [[KryoData]] and serializes it as KryoData(0). + */ class ZeroKryoDataSerializer extends Serializer[KryoData] { override def write(kryo: Kryo, output: Output, t: KryoData): Unit = { output.writeInt(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 26967782f77c..2108b118bf05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -44,8 +44,8 @@ abstract class FileStreamSourceTest import testImplicits._ /** - * A subclass [[AddData]] for adding data to files. 
This is meant to use the - * [[FileStreamSource]] actually being used in the execution. + * A subclass `AddData` for adding data to files. This is meant to use the + * `FileStreamSource` actually being used in the execution. */ abstract class AddFileData extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 5ab9dc2bc776..13fe51a55773 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -569,7 +569,7 @@ class ThrowingIOExceptionLikeHadoop12074 extends FakeSource { object ThrowingIOExceptionLikeHadoop12074 { /** - * A latch to allow the user to wait until [[ThrowingIOExceptionLikeHadoop12074.createSource]] is + * A latch to allow the user to wait until `ThrowingIOExceptionLikeHadoop12074.createSource` is * called. */ @volatile var createSourceLatch: CountDownLatch = null @@ -600,7 +600,7 @@ class ThrowingInterruptedIOException extends FakeSource { object ThrowingInterruptedIOException { /** - * A latch to allow the user to wait until [[ThrowingInterruptedIOException.createSource]] is + * A latch to allow the user to wait until `ThrowingInterruptedIOException.createSource` is * called. */ @volatile var createSourceLatch: CountDownLatch = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2ebbfcd22b97..b69536ed3746 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -642,8 +642,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi * * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * When timeoutMs is less than or equal to 0, awaitTermination() is + * tested (i.e. w/o timeout) + * When timeoutMs is greater than 0, awaitTermination(timeoutMs) is + * tested * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used */ case class TestAwaitTermination( @@ -667,8 +669,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi * * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * When timeoutMs is less than or equal to 0, awaitTermination() is + * tested (i.e. 
w/o timeout) + * When timeoutMs is greater than 0, awaitTermination(timeoutMs) is + * tested * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used */ def assertOnQueryCondition( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index cab219216d1c..6a4cc95d36be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -41,11 +41,11 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} /** * Helper trait that should be extended by all SQL test suites. * - * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. - * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. * - * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ private[sql] trait SQLTestUtils @@ -65,7 +65,7 @@ private[sql] trait SQLTestUtils * A helper object for importing SQL implicits. * * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, + * This is because we create the `SQLContext` immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { @@ -73,7 +73,7 @@ private[sql] trait SQLTestUtils } /** - * Materialize the test data immediately after the [[SQLContext]] is set up. + * Materialize the test data immediately after the `SQLContext` is set up. * This is necessary if the data is accessed by name but not through direct reference. */ protected def setupTestData(): Unit = { @@ -250,8 +250,8 @@ private[sql] trait SQLTestUtils } /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier - * way to construct [[DataFrame]] directly out of local data without relying on implicits. + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { Dataset.ofRows(spark, plan) @@ -271,7 +271,9 @@ private[sql] trait SQLTestUtils } } - /** Run a test on a separate [[UninterruptibleThread]]. */ + /** + * Run a test on a separate `UninterruptibleThread`. 
+ */ protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) (body: => Unit): Unit = { val timeoutMillis = 10000 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index b01977a23890..959edf9a4937 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf} /** - * A special [[SparkSession]] prepared for testing. + * A special `SparkSession` prepared for testing. */ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java index b95077cd6218..0d0e3e4011b5 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java @@ -49,7 +49,7 @@ enum STATE { * The transition must be from {@link STATE#NOTINITED} to {@link STATE#INITED} unless the * operation failed and an exception was raised. * - * @param config + * @param conf * the configuration of the service */ void init(HiveConf conf); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java index a2c580d6acc7..c3219aabfc23 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java @@ -51,7 +51,7 @@ public static void ensureCurrentState(Service.STATE state, /** * Initialize a service. - *
<p/>
    + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -69,7 +69,7 @@ public static void init(Service service, HiveConf configuration) { /** * Start a service. - *
<p/>
    + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -86,7 +86,7 @@ public static void start(Service service) { /** * Initialize then start a service. - *
<p/>
    + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -102,9 +102,9 @@ public static void deploy(Service service, HiveConf configuration) { /** * Stop a service. - *
<p/>
    Do nothing if the service is null or not - * in a state in which it can be/needs to be stopped. - *
<p/>
    + * + * Do nothing if the service is null or not in a state in which it can be/needs to be stopped. + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service or null diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java index 502152829968..f7375ee70783 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java @@ -89,7 +89,7 @@ public static String getKerberosServiceTicket(String principal, String host, * @param clientUserName Client User name. * @return An unsigned cookie token generated from input parameters. * The final cookie generated is of the following format : - * cu=&rn=&s= + * {@code cu=&rn=&s=} */ public static String createCookieToken(String clientUserName) { StringBuffer sb = new StringBuffer(); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java index e2a6de165adc..1af1c1d06e7f 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java @@ -26,7 +26,7 @@ public interface PasswdAuthenticationProvider { * to authenticate users for their requests. * If a user is to be granted, return nothing/throw nothing. * When a user is to be disallowed, throw an appropriate {@link AuthenticationException}. - *
<p/>
    + * * For an example implementation, see {@link LdapAuthenticationProviderImpl}. * * @param user The username received over the connection request diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java index 645e3e2bbd4e..9a61ad49942c 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java @@ -31,12 +31,9 @@ /** * This class is responsible for setting the ipAddress for operations executed via HiveServer2. - *
<p>
- * <ul>
- * <li>IP address is only set for operations that calls listeners with hookContext</li>
- * <li>IP address is only set if the underlying transport mechanism is socket</li>
- * </ul>
- * </p>
    + * + * - IP address is only set for operations that calls listeners with hookContext + * - IP address is only set if the underlying transport mechanism is socket * * @see org.apache.hadoop.hive.ql.hooks.ExecuteWithHookContext */ diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java index 9d64b102e008..bf2380632fa6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java @@ -38,7 +38,7 @@ public class CLIServiceUtils { * Convert a SQL search pattern into an equivalent Java Regex. * * @param pattern input which may contain '%' or '_' wildcard characters, or - * these characters escaped using {@link #getSearchStringEscape()}. + * these characters escaped using {@code getSearchStringEscape()}. * @return replace %/_ with regex search characters, also handle escaped * characters. */ diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java index 05a6bf938404..af36057bdaec 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java @@ -28,9 +28,9 @@ /** * ClassicTableTypeMapping. * Classic table type mapping : - * Managed Table ==> Table - * External Table ==> Table - * Virtual View ==> View + * Managed Table to Table + * External Table to Table + * Virtual View to View */ public class ClassicTableTypeMapping implements TableTypeMapping { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java index e392c459cf58..e59d19ea6be4 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java @@ -31,7 +31,7 @@ public interface TableTypeMapping { /** * Map hive's table type name to client's table type - * @param clientTypeName + * @param hiveTypeName * @return */ String mapToClientType(String hiveTypeName); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java index de066dd406c7..c1b3892f5206 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java @@ -224,7 +224,9 @@ public SessionHandle openSession(TProtocolVersion protocol, String username, Str * The username passed to this method is the effective username. * If withImpersonation is true (==doAs true) we wrap all the calls in HiveSession * within a UGI.doAs, where UGI corresponds to the effective user. - * @see org.apache.hive.service.cli.thrift.ThriftCLIService#getUserName() + * + * Please see {@code org.apache.hive.service.cli.thrift.ThriftCLIService.getUserName()} for + * more details. 
* * @param protocol * @param username diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java index fb8141a905ac..94f8126552e9 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java @@ -30,12 +30,12 @@ * in custom cleanup code to be called before this thread is GC-ed. * Currently cleans up the following: * 1. ThreadLocal RawStore object: - * In case of an embedded metastore, HiveServer2 threads (foreground & background) + * In case of an embedded metastore, HiveServer2 threads (foreground and background) * end up caching a ThreadLocal RawStore object. The ThreadLocal RawStore object has - * an instance of PersistenceManagerFactory & PersistenceManager. + * an instance of PersistenceManagerFactory and PersistenceManager. * The PersistenceManagerFactory keeps a cache of PersistenceManager objects, * which are only removed when PersistenceManager#close method is called. - * HiveServer2 uses ExecutorService for managing thread pools for foreground & background threads. + * HiveServer2 uses ExecutorService for managing thread pools for foreground and background threads. * ExecutorService unfortunately does not provide any hooks to be called, * when a thread from the pool is terminated. * As a solution, we're using this ThreadFactory to keep a cache of RawStore objects per thread. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 6f5b923cd4f9..4dec2f71b8a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -53,8 +53,8 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Date * java.sql.Timestamp * Complex Types => - * Map: [[MapData]] - * List: [[ArrayData]] + * Map: `MapData` + * List: `ArrayData` * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index e772324a57ab..bb4ce6d3aa3f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * TestSuites that derive from this class must provide a map of testCaseName to testCaseFiles * that should be included. Additionally, there is support for whitelisting and blacklisting * tests as development progresses. 
*/ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 7226ed521ef3..a2f08c5ba72c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -43,7 +43,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { } /** - * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * Writes `data` to a Orc file and reads it back as a `DataFrame`, * which is then passed to `f`. The Orc file will be deleted after `f` returns. */ protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] @@ -53,7 +53,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { } /** - * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * Writes `data` to a Orc file, reads it back as a `DataFrame` and registers it as a * temporary table named `tableName`, then call `f`. The temporary table together with the * Orc file will be dropped/deleted after `f` returns. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 58b7031d5ea6..15d3c7e54b8d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils /** - * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a + * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a `StateMap` and a * sequence of records returned by the mapping function of `mapWithState`. */ private[streaming] case class MapWithStateRDDRecord[K, S, E]( @@ -111,7 +111,7 @@ private[streaming] class MapWithStateRDDPartition( /** * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data. * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a - * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping + * `StateMap` (containing the keyed-states) and the sequence of records returned by the mapping * function of `mapWithState`. * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD * will be created diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index a73e6cc2cd9c..dc02062b9eb4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging * case of Spark Streaming the error is the difference between the measured processing * rate (number of elements/processing delay) and the previous rate. 
* - * @see https://en.wikipedia.org/wiki/PID_controller + * @see PID controller (Wikipedia) * * @param batchIntervalMillis the batch duration, in milliseconds * @param proportional how much the correction should depend on the current diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 7b2ef6881d6f..e4b9dffee04f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.Duration * A component that estimates the rate at which an `InputDStream` should ingest * records, based on updates at every batch completion. * - * @see [[org.apache.spark.streaming.scheduler.RateController]] + * Please see `org.apache.spark.streaming.scheduler.RateController` for more details. */ private[streaming] trait RateEstimator extends Serializable { From 504e62e2f4b7df7e002ea014a855cebe1ff95193 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 12 Apr 2017 09:01:26 -0700 Subject: [PATCH 271/512] [SPARK-20303][SQL] Rename createTempFunction to registerFunction ### What changes were proposed in this pull request? Session catalog API `createTempFunction` is being used by Hive build-in functions, persistent functions, and temporary functions. Thus, the name is confusing. This PR is to rename it by `registerFunction`. Also we can move construction of `FunctionBuilder` and `ExpressionInfo` into the new `registerFunction`, instead of duplicating the logics everywhere. In the next PRs, the remaining Function-related APIs also need cleanups. ### How was this patch tested? Existing test cases. Author: Xiao Li Closes #17615 from gatorsmile/cleanupCreateTempFunction. 
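To make the rename concrete, here is a minimal sketch of how a call site migrates under this patch, based on the `registerFunction` signature in the diff below. It assumes a `SessionCatalog` instance named `catalog` and the `newFunc` test helper from `SessionCatalogSuite` are in scope, so it is an illustration rather than a standalone program.

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal}

// A FunctionBuilder: maps the argument expressions to the resulting Expression.
val builder = (args: Seq[Expression]) => args.head

// Before this patch: name, ExpressionInfo and builder were passed separately.
// catalog.createTempFunction(
//   "temp1", new ExpressionInfo("tempFunc1", "temp1"), builder, ignoreIfExists = false)

// After this patch: a CatalogFunction carries the identifier and class name, and
// registerFunction derives the ExpressionInfo (and, if no builder is given, the
// FunctionBuilder) from it.
catalog.registerFunction(
  newFunc("temp1", None),              // CatalogFunction, via the suite's newFunc helper
  ignoreIfExists = false,
  functionBuilder = Some(builder))

// Lookup through the session-specific FunctionRegistry works as before.
catalog.lookupFunction(FunctionIdentifier("temp1"), Seq(Literal(1), Literal(2)))
```

The design point is that `registerFunction` now builds the `ExpressionInfo` (and, when no builder is supplied, the `FunctionBuilder`) from the `CatalogFunction` itself, so callers no longer duplicate that construction.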
--- .../analysis/AlreadyExistException.scala | 3 -- .../sql/catalyst/catalog/SessionCatalog.scala | 31 +++++++------- .../catalog/SessionCatalogSuite.scala | 40 ++++++++++--------- .../sql/execution/command/functions.scala | 9 ++--- .../spark/sql/internal/CatalogSuite.scala | 5 ++- .../spark/sql/hive/HiveSessionCatalog.scala | 18 +++------ .../ObjectHashAggregateExecBenchmark.scala | 10 +++-- 7 files changed, 53 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index ec56fe7729c2..57f7a80bedc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -44,6 +44,3 @@ class PartitionsAlreadyExistException(db: String, table: String, specs: Seq[Tabl class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") - -class TempFunctionAlreadyExistsException(func: String) - extends AnalysisException(s"Temporary function '$func' already exists") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index faedf5f91c3e..1417bccf657c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1050,7 +1050,7 @@ class SessionCatalog( * * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { // TODO: at least support UDAFs here throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") } @@ -1064,18 +1064,20 @@ class SessionCatalog( } /** - * Create a temporary function. - * This assumes no database is specified in `funcDefinition`. + * Registers a temporary or permanent function into a session-specific [[FunctionRegistry]] */ - def createTempFunction( - name: String, - info: ExpressionInfo, - funcDefinition: FunctionBuilder, - ignoreIfExists: Boolean): Unit = { - if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { - throw new TempFunctionAlreadyExistsException(name) + def registerFunction( + funcDefinition: CatalogFunction, + ignoreIfExists: Boolean, + functionBuilder: Option[FunctionBuilder] = None): Unit = { + val func = funcDefinition.identifier + if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) { + throw new AnalysisException(s"Function $func already exists") } - functionRegistry.registerFunction(name, info, funcDefinition) + val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) + val builder = + functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) + functionRegistry.registerFunction(func.unquotedString, info, builder) } /** @@ -1180,12 +1182,7 @@ class SessionCatalog( // catalog. So, it is possible that qualifiedName is not exactly the same as // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). 
// At here, we preserve the input from the user. - val info = new ExpressionInfo( - catalogFunction.className, - qualifiedName.database.orNull, - qualifiedName.funcName) - val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className) - createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false) + registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(qualifiedName.unquotedString, children) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 9ba846fb2527..be8903000a0d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1162,10 +1162,10 @@ abstract class SessionCatalogSuite extends PlanTest { withBasicCatalog { catalog => val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last - val info1 = new ExpressionInfo("tempFunc1", "temp1") - val info2 = new ExpressionInfo("tempFunc2", "temp2") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction( + newFunc("temp2", None), ignoreIfExists = false, functionBuilder = Some(tempFunc2)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) @@ -1174,13 +1174,15 @@ abstract class SessionCatalogSuite extends PlanTest { catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) } val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) - val info3 = new ExpressionInfo("tempFunc3", "temp1") // Temporary function already exists - intercept[TempFunctionAlreadyExistsException] { - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) - } + val e = intercept[AnalysisException] { + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc3)) + }.getMessage + assert(e.contains("Function temp1 already exists")) // Temporary function is overridden - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = true, functionBuilder = Some(tempFunc3)) assert( catalog.lookupFunction( FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) @@ -1193,8 +1195,8 @@ abstract class SessionCatalogSuite extends PlanTest { assert(!catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) val tempFunc1 = (e: Seq[Expression]) => e.head - val info1 = new ExpressionInfo("tempFunc1", "temp1") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) // Returns true when the function is temporary assert(catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) @@ -1243,9 +1245,9 @@ abstract class SessionCatalogSuite extends PlanTest { test("drop temp function") { 
withBasicCatalog { catalog => - val info = new ExpressionInfo("tempFunc", "func1") val tempFunc = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) + catalog.registerFunction( + newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc)) val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1284,9 +1286,9 @@ abstract class SessionCatalogSuite extends PlanTest { test("lookup temp function") { withBasicCatalog { catalog => - val info1 = new ExpressionInfo("tempFunc1", "func1") val tempFunc1 = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + catalog.registerFunction( + newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) assert(catalog.lookupFunction( FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) @@ -1298,14 +1300,14 @@ abstract class SessionCatalogSuite extends PlanTest { test("list functions") { withBasicCatalog { catalog => - val info1 = new ExpressionInfo("tempFunc1", "func1") - val info2 = new ExpressionInfo("tempFunc2", "yes_me") + val funcMeta1 = newFunc("func1", None) + val funcMeta2 = newFunc("yes_me", None) val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) + catalog.registerFunction(funcMeta1, ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction(funcMeta2, ignoreIfExists = false, functionBuilder = Some(tempFunc2)) assert(catalog.listFunctions("db1", "*").map(_._1).toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 5687f9332430..e0d002936957 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -51,6 +51,7 @@ case class CreateFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val func = CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources) if (isTemp) { if (databaseName.isDefined) { throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + @@ -59,17 +60,13 @@ case class CreateFunctionCommand( // We first load resources and then put the builder in the function registry. // Please note that it is allowed to overwrite an existing temp function. 
catalog.loadFunctionResources(resources) - val info = new ExpressionInfo(className, functionName) - val builder = catalog.makeFunctionBuilder(functionName, className) - catalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + catalog.registerFunction(func, ignoreIfExists = false) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. // TODO: should we also parse "IF NOT EXISTS"? - catalog.createFunction( - CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources), - ignoreIfExists = false) + catalog.createFunction(func, ignoreIfExists = false) } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 6469e501c1f6..8f9c52cb1e03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -75,9 +75,10 @@ class CatalogSuite } private def createTempFunction(name: String): Unit = { - val info = new ExpressionInfo("className", name) val tempFunc = (e: Seq[Expression]) => e.head - sessionCatalog.createTempFunction(name, info, tempFunc, ignoreIfExists = false) + val funcMeta = CatalogFunction(FunctionIdentifier(name, None), "className", Nil) + sessionCatalog.registerFunction( + funcMeta, ignoreIfExists = false, functionBuilder = Some(tempFunc)) } private def dropFunction(name: String, db: Option[String] = None): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index c917f110b90f..377d4f2473c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf @@ -124,13 +124,6 @@ private[sql] class HiveSessionCatalog( } private def lookupFunction0(name: FunctionIdentifier, children: Seq[Expression]): Expression = { - // TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to - // if (super.functionExists(name)) { - // super.lookupFunction(name, children) - // } else { - // // This function is a Hive builtin function. - // ... 
- // } val database = name.database.map(formatDatabaseName) val funcName = name.copy(database = database) Try(super.lookupFunction(funcName, children)) match { @@ -164,10 +157,11 @@ private[sql] class HiveSessionCatalog( } } val className = functionInfo.getFunctionClass.getName - val builder = makeFunctionBuilder(functionName, className) + val functionIdentifier = + FunctionIdentifier(functionName.toLowerCase(Locale.ROOT), database) + val func = CatalogFunction(functionIdentifier, className, Nil) // Put this Hive built-in function to our function registry. - val info = new ExpressionInfo(className, functionName) - createTempFunction(functionName, info, builder, ignoreIfExists = false) + registerFunction(func, ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(functionName, children) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 197110f4912a..73383ae4d411 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -22,7 +22,9 @@ import scala.concurrent.duration._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.hive.HiveSessionCatalog import org.apache.spark.sql.hive.execution.TestingTypedCount @@ -217,9 +219,9 @@ class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingle private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = { val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName) - val info = new ExpressionInfo(clazz.getName, functionName) - sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + val functionIdentifier = FunctionIdentifier(functionName, database = None) + val func = CatalogFunction(functionIdentifier, clazz.getName, resources = Nil) + sessionCatalog.registerFunction(func, ignoreIfExists = false) } private def percentile_approx( From 540855382c8f139fbf4eb0800b31c7ce91f29c7f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 12 Apr 2017 09:05:05 -0700 Subject: [PATCH 272/512] [SPARK-20304][SQL] AssertNotNull should not include path in string representation ## What changes were proposed in this pull request? AssertNotNull's toString/simpleString dumps the entire walkedTypePath. walkedTypePath is used for error message reporting and shouldn't be part of the output. ## How was this patch tested? Manually tested. Author: Reynold Xin Closes #17616 from rxin/SPARK-20304. 
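To illustrate why overriding `flatArguments` is enough to drop the path from the output, here is a tiny stand-alone analogy in which the rendered string is derived from `flatArguments`. The class and object names are invented for illustration; this is not Spark's TreeNode/Expression machinery, only a sketch of the mechanism the one-line fix below relies on.

```scala
// Invented names; an analogy for how the rendered form follows flatArguments.
trait TinyExpr extends Product {
  // By default every constructor argument takes part in the rendered string.
  def flatArguments: Iterator[Any] = productIterator
  def simpleString: String = s"$productPrefix(${flatArguments.mkString(", ")})"
}

case class AssertNotNullLike(child: String, walkedTypePath: Seq[String] = Nil) extends TinyExpr {
  // Mirrors the one-line fix: only the child participates, so the potentially long
  // walkedTypePath kept around for error reporting never leaks into the output.
  override def flatArguments: Iterator[Any] = Iterator(child)
}

object AssertNotNullLikeDemo extends App {
  // Prints "AssertNotNullLike(price)" rather than also dumping the walked type path.
  println(AssertNotNullLike("price", Seq("root class: Item", "field: price")).simpleString)
}
```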
--- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6d94764f1bfa..eed773d4cb36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -996,6 +996,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) override def foldable: Boolean = false override def nullable: Boolean = false + override def flatArguments: Iterator[Any] = Iterator(child) + private val errMsg = "Null value appeared in non-nullable field:" + walkedTypePath.mkString("\n", "\n", "\n") + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + From 99a9473127ec389283ac4ec3b721d2e34434e647 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 12 Apr 2017 10:54:50 -0700 Subject: [PATCH 273/512] [SPARK-19570][PYSPARK] Allow to disable hive in pyspark shell ## What changes were proposed in this pull request? SPARK-15236 do this for scala shell, this ticket is for pyspark shell. This is not only for pyspark itself, but can also benefit downstream project like livy which use shell.py for its interactive session. For now, livy has no control of whether enable hive or not. ## How was this patch tested? I didn't find a way to add test for it. Just manually test it. Run `bin/pyspark --master local --conf spark.sql.catalogImplementation=in-memory` and verify hive is not enabled. Author: Jeff Zhang Closes #16906 from zjffdu/SPARK-19570. --- python/pyspark/shell.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index c1917d2be69d..b5fcf7092d93 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -24,13 +24,13 @@ import atexit import os import platform +import warnings import py4j -import pyspark +from pyspark import SparkConf from pyspark.context import SparkContext from pyspark.sql import SparkSession, SQLContext -from pyspark.storagelevel import StorageLevel if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -39,13 +39,23 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + spark = SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") spark = SparkSession.builder.getOrCreate() except TypeError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext From 
924c42477b5d6ed3c217c8eaaf4dc64b2379851a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 12 Apr 2017 11:24:59 -0700 Subject: [PATCH 274/512] [SPARK-20301][FLAKY-TEST] Fix Hadoop Shell.runCommand flakiness in Structured Streaming tests ## What changes were proposed in this pull request? Some Structured Streaming tests show flakiness such as: ``` [info] - prune results by current_date, complete mode - 696 *** FAILED *** (10 seconds, 937 milliseconds) [info] Timed out while stopping and waiting for microbatchthread to terminate.: The code passed to failAfter did not complete within 10 seconds. ``` This happens when we wait for the stream to stop, but it doesn't. The reason it doesn't stop is that we interrupt the microBatchThread, but Hadoop's `Shell.runCommand` swallows the interrupt exception, and the exception is not propagated upstream to the microBatchThread. Then this thread continues to run, only to start blocking on the `streamManualClock`. ## How was this patch tested? Thousand retries locally and [Jenkins](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/75720/testReport) of the flaky tests Author: Burak Yavuz Closes #17613 from brkyvz/flaky-stream-agg. --- .../execution/streaming/StreamExecution.scala | 56 +++++++++---------- .../spark/sql/streaming/StreamTest.scala | 6 ++ 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 8857966676ae..bcf0d970f7ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -284,42 +284,38 @@ class StreamExecution( triggerExecutor.execute(() => { startTrigger() - val continueToRun = - if (isActive) { - reportTimeTaken("triggerExecution") { - if (currentBatchId < 0) { - // We'll do this initialization only once - populateStartOffsets(sparkSessionToRunBatches) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() - } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") - runBatch(sparkSessionToRunBatches) - } + if (isActive) { + reportTimeTaken("triggerExecution") { + if (currentBatchId < 0) { + // We'll do this initialization only once + populateStartOffsets(sparkSessionToRunBatches) + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + } else { + constructNextBatch() } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) if (dataAvailable) { - // Update committed offsets. - batchCommitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) + currentStatus = currentStatus.copy(isDataAvailable = true) + updateStatusMessage("Processing new data") + runBatch(sparkSessionToRunBatches) } - true + } + // Report trigger as finished and construct progress object. + finishTrigger(dataAvailable) + if (dataAvailable) { + // Update committed offsets. 
+ batchCommitLog.add(currentBatchId) + committedOffsets ++= availableOffsets + logDebug(s"batch ${currentBatchId} committed") + // We'll increase currentBatchId after we complete processing current batch's data + currentBatchId += 1 } else { - false + currentStatus = currentStatus.copy(isDataAvailable = false) + updateStatusMessage("Waiting for data to arrive") + Thread.sleep(pollingDelayMs) } + } updateStatusMessage("Waiting for next trigger") - continueToRun + isActive }) updateStatusMessage("Stopped") } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 03aa45b61688..5bc36dd30f6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -277,6 +277,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def threadState = if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def threadStackTrace = if (currentStream != null && currentStream.microBatchThread.isAlive) { + s"Thread stack trace: ${currentStream.microBatchThread.getStackTrace.mkString("\n")}" + } else { + "" + } def testState = s""" @@ -287,6 +292,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { |Output Mode: $outputMode |Stream state: $currentOffsets |Thread state: $threadState + |$threadStackTrace |${if (streamThreadDeathCause != null) stackTraceToString(streamThreadDeathCause) else ""} | |== Sink == From a7b430b5717e263c1fbb55114deca6028ea9c3b3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Apr 2017 08:38:24 +0800 Subject: [PATCH 275/512] [SPARK-15354][FLAKY-TEST] TopologyAwareBlockReplicationPolicyBehavior.Peers in 2 racks ## What changes were proposed in this pull request? `TopologyAwareBlockReplicationPolicyBehavior.Peers in 2 racks` is failing occasionally: https://spark-tests.appspot.com/test-details?suite_name=org.apache.spark.storage.TopologyAwareBlockReplicationPolicyBehavior&test_name=Peers+in+2+racks. This is because, when we generate 10 block manager id to test, they may all belong to the same rack, as the rack is randomly picked. This PR fixes this problem by forcing each rack to be picked at least once. ## How was this patch tested? N/A Author: Wenchen Fan Closes #17624 from cloud-fan/test. --- .../spark/storage/BlockReplicationPolicySuite.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index ecad0f5352e5..dfecd04c1b96 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -70,9 +70,18 @@ class RandomBlockReplicationPolicyBehavior extends SparkFunSuite } } + /** + * Returns a sequence of [[BlockManagerId]], whose rack is randomly picked from the given `racks`. + * Note that, each rack will be picked at least once from `racks`, if `count` is greater or equal + * to the number of `racks`. 
+ */ protected def generateBlockManagerIds(count: Int, racks: Seq[String]): Seq[BlockManagerId] = { - (1 to count).map{i => - BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(racks(Random.nextInt(racks.size)))) + val randomizedRacks: Seq[String] = Random.shuffle( + racks ++ racks.length.until(count).map(_ => racks(Random.nextInt(racks.length))) + ) + + (0 until count).map { i => + BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(randomizedRacks(i))) } } } From c5f1cc370f0aa1f0151fd34251607a8de861395e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 12 Apr 2017 17:44:18 -0700 Subject: [PATCH 276/512] [SPARK-20131][CORE] Don't use `this` lock in StandaloneSchedulerBackend.stop ## What changes were proposed in this pull request? `o.a.s.streaming.StreamingContextSuite.SPARK-18560 Receiver data should be deserialized properly` is flaky is because there is a potential dead-lock in StandaloneSchedulerBackend which causes `await` timeout. Here is the related stack trace: ``` "Thread-31" #211 daemon prio=5 os_prio=31 tid=0x00007fedd4808000 nid=0x16403 waiting on condition [0x00007000239b7000] java.lang.Thread.State: TIMED_WAITING (parking) at sun.misc.Unsafe.park(Native Method) - parking to wait for <0x000000079b49ca10> (a scala.concurrent.impl.Promise$CompletionLatch) at java.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215) at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedNanos(AbstractQueuedSynchronizer.java:1037) at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1328) at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208) at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:201) at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:92) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:76) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stop(CoarseGrainedSchedulerBackend.scala:402) at org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend.org$apache$spark$scheduler$cluster$StandaloneSchedulerBackend$$stop(StandaloneSchedulerBackend.scala:213) - locked <0x00000007066fca38> (a org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend) at org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend.stop(StandaloneSchedulerBackend.scala:116) - locked <0x00000007066fca38> (a org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend) at org.apache.spark.scheduler.TaskSchedulerImpl.stop(TaskSchedulerImpl.scala:517) at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:1657) at org.apache.spark.SparkContext$$anonfun$stop$8.apply$mcV$sp(SparkContext.scala:1921) at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1302) at org.apache.spark.SparkContext.stop(SparkContext.scala:1920) at org.apache.spark.streaming.StreamingContext.stop(StreamingContext.scala:708) at org.apache.spark.streaming.StreamingContextSuite$$anonfun$43$$anonfun$apply$mcV$sp$66$$anon$3.run(StreamingContextSuite.scala:827) "dispatcher-event-loop-3" #18 daemon prio=5 os_prio=31 tid=0x00007fedd603a000 nid=0x6203 waiting for monitor entry [0x0000700003be4000] java.lang.Thread.State: BLOCKED (on object monitor) at 
org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend$DriverEndpoint.org$apache$spark$scheduler$cluster$CoarseGrainedSchedulerBackend$DriverEndpoint$$makeOffers(CoarseGrainedSchedulerBackend.scala:253) - waiting to lock <0x00000007066fca38> (a org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend$DriverEndpoint$$anonfun$receive$1.applyOrElse(CoarseGrainedSchedulerBackend.scala:124) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:117) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:205) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:101) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:213) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` This PR removes `synchronized` and changes `stopping` to AtomicBoolean to ensure idempotent to fix the dead-lock. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17610 from zsxwing/SPARK-20131. --- .../cluster/StandaloneSchedulerBackend.scala | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 7befdb0c1f64..0529fe9eed4d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future @@ -42,7 +43,7 @@ private[spark] class StandaloneSchedulerBackend( with Logging { private var client: StandaloneAppClient = null - private var stopping = false + private val stopping = new AtomicBoolean(false) private val launcherBackend = new LauncherBackend() { override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) } @@ -112,7 +113,7 @@ private[spark] class StandaloneSchedulerBackend( launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop(): Unit = synchronized { + override def stop(): Unit = { stop(SparkAppHandle.State.FINISHED) } @@ -125,14 +126,14 @@ private[spark] class StandaloneSchedulerBackend( override def disconnected() { notifyContext() - if (!stopping) { + if (!stopping.get) { logWarning("Disconnected from Spark cluster! Waiting for reconnection...") } } override def dead(reason: String) { notifyContext() - if (!stopping) { + if (!stopping.get) { launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. 
Reason: " + reason) try { @@ -206,20 +207,20 @@ private[spark] class StandaloneSchedulerBackend( registrationBarrier.release() } - private def stop(finalState: SparkAppHandle.State): Unit = synchronized { - try { - stopping = true - - super.stop() - client.stop() + private def stop(finalState: SparkAppHandle.State): Unit = { + if (stopping.compareAndSet(false, true)) { + try { + super.stop() + client.stop() - val callback = shutdownCallback - if (callback != null) { - callback(this) + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() } - } finally { - launcherBackend.setState(finalState) - launcherBackend.close() } } From ec68d8f8cfdede8a0de1d56476205158544cc4eb Mon Sep 17 00:00:00 2001 From: Yash Sharma Date: Thu, 13 Apr 2017 08:49:19 +0100 Subject: [PATCH 277/512] [SPARK-20189][DSTREAM] Fix spark kinesis testcases to remove deprecated createStream and use Builders ## What changes were proposed in this pull request? The spark-kinesis testcases use the KinesisUtils.createStream which are deprecated now. Modify the testcases to use the recommended KinesisInputDStream.builder instead. This change will also enable the testcases to automatically use the session tokens automatically. ## How was this patch tested? All the existing testcases work fine as expected with the changes. https://issues.apache.org/jira/browse/SPARK-20189 Author: Yash Sharma Closes #17506 from yssharma/ysharma/cleanup_kinesis_testcases. --- .../kinesis/KinesisInputDStream.scala | 2 +- .../kinesis/KinesisStreamSuite.scala | 58 ++++++++++++------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 8970ad2bafda..77553412eda5 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -267,7 +267,7 @@ object KinesisInputDStream { getRequiredParam(checkpointAppName, "checkpointAppName"), checkpointInterval.getOrElse(ssc.graph.batchDuration), storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), - handler, + ssc.sc.clean(handler), kinesisCredsProvider.getOrElse(DefaultCredentials), dynamoDBCredsProvider, cloudWatchCredsProvider) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ed7e35805026..341a6898cbbf 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -22,7 +22,6 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} @@ -173,11 +172,15 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . 
*/ testIfEnabled("basic operation") { - val awsCredentials = KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => @@ -198,12 +201,17 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } testIfEnabled("custom message handling") { - val awsCredentials = KinesisTestUtils.getAWSCredentials() def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 - val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, addFive(_), - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .buildWithMessageHandler(addFive(_)) stream shouldBe a [ReceiverInputDStream[_]] @@ -233,11 +241,15 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun val localTestUtils = new KPLBasedKinesisTestUtils(1) localTestUtils.createStream() try { - val awsCredentials = KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, localAppName, localTestUtils.streamName, - localTestUtils.endpointUrl, localTestUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(localAppName) + .streamName(localTestUtils.streamName) + .endpointUrl(localTestUtils.endpointUrl) + .regionName(localTestUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => @@ -303,13 +315,17 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun ssc = new StreamingContext(sc, Milliseconds(1000)) ssc.checkpoint(checkpointDir) - val awsCredentials = KinesisTestUtils.getAWSCredentials() val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] - val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val kinesisStream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + 
.streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { From 095d1cb3aa0021c9078a6e910967b9189ddfa177 Mon Sep 17 00:00:00 2001 From: Syrux Date: Thu, 13 Apr 2017 09:44:33 +0100 Subject: [PATCH 278/512] [SPARK-20265][MLLIB] Improve Prefix'span pre-processing efficiency ## What changes were proposed in this pull request? Improve PrefixSpan pre-processing efficency by preventing sequences of zero in the cleaned database. The efficiency gain is reflected in the following graph : https://postimg.org/image/9x6ireuvn/ ## How was this patch tested? Using MLlib's PrefixSpan existing tests and tests of my own on the 8 datasets shown in the graph. All result obtained were stricly the same as the original implementation (without this change). dev/run-tests was also runned, no error were found. Author : Cyril de Vogelaere Author: Syrux Closes #17575 from Syrux/SPARK-20265. --- .../apache/spark/mllib/fpm/PrefixSpan.scala | 99 ++++++++++++------- .../spark/mllib/fpm/PrefixSpanSuite.scala | 51 ++++++++++ 2 files changed, 115 insertions(+), 35 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 327cb974ef96..3f8d65a378e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -144,45 +144,13 @@ class PrefixSpan private ( logInfo(s"minimum count for a frequent pattern: $minCount") // Find frequent items. - val freqItemAndCounts = data.flatMap { itemsets => - val uniqItems = mutable.Set.empty[Item] - itemsets.foreach { _.foreach { item => - uniqItems += item - }} - uniqItems.toIterator.map((_, 1L)) - }.reduceByKey(_ + _) - .filter { case (_, count) => - count >= minCount - }.collect() - val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1) + val freqItems = findFrequentItems(data, minCount) logInfo(s"number of frequent items: ${freqItems.length}") // Keep only frequent items from input sequences and convert them to internal storage. val itemToInt = freqItems.zipWithIndex.toMap - val dataInternalRepr = data.flatMap { itemsets => - val allItems = mutable.ArrayBuilder.make[Int] - var containsFreqItems = false - allItems += 0 - itemsets.foreach { itemsets => - val items = mutable.ArrayBuilder.make[Int] - itemsets.foreach { item => - if (itemToInt.contains(item)) { - items += itemToInt(item) + 1 // using 1-indexing in internal format - } - } - val result = items.result() - if (result.nonEmpty) { - containsFreqItems = true - allItems ++= result.sorted - } - allItems += 0 - } - if (containsFreqItems) { - Iterator.single(allItems.result()) - } else { - Iterator.empty - } - }.persist(StorageLevel.MEMORY_AND_DISK) + val dataInternalRepr = toDatabaseInternalRepr(data, itemToInt) + .persist(StorageLevel.MEMORY_AND_DISK) val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize) @@ -231,6 +199,67 @@ class PrefixSpan private ( @Since("1.5.0") object PrefixSpan extends Logging { + /** + * This methods finds all frequent items in a input dataset. + * + * @param data Sequences of itemsets. 
+ * @param minCount The minimal number of sequence an item should be present in to be frequent + * + * @return An array of Item containing only frequent items. + */ + private[fpm] def findFrequentItems[Item: ClassTag]( + data: RDD[Array[Array[Item]]], + minCount: Long): Array[Item] = { + + data.flatMap { itemsets => + val uniqItems = mutable.Set.empty[Item] + itemsets.foreach(set => uniqItems ++= set) + uniqItems.toIterator.map((_, 1L)) + }.reduceByKey(_ + _).filter { case (_, count) => + count >= minCount + }.sortBy(-_._2).map(_._1).collect() + } + + /** + * This methods cleans the input dataset from un-frequent items, and translate it's item + * to their corresponding Int identifier. + * + * @param data Sequences of itemsets. + * @param itemToInt A map allowing translation of frequent Items to their Int Identifier. + * The map should only contain frequent item. + * + * @return The internal repr of the inputted dataset. With properly placed zero delimiter. + */ + private[fpm] def toDatabaseInternalRepr[Item: ClassTag]( + data: RDD[Array[Array[Item]]], + itemToInt: Map[Item, Int]): RDD[Array[Int]] = { + + data.flatMap { itemsets => + val allItems = mutable.ArrayBuilder.make[Int] + var containsFreqItems = false + allItems += 0 + itemsets.foreach { itemsets => + val items = mutable.ArrayBuilder.make[Int] + itemsets.foreach { item => + if (itemToInt.contains(item)) { + items += itemToInt(item) + 1 // using 1-indexing in internal format + } + } + val result = items.result() + if (result.nonEmpty) { + containsFreqItems = true + allItems ++= result.sorted + allItems += 0 + } + } + if (containsFreqItems) { + Iterator.single(allItems.result()) + } else { + Iterator.empty + } + } + } + /** * Find the complete set of frequent sequential patterns in the input sequences. * @param data ordered sequences of itemsets. 
We represent a sequence internally as Array[Int], diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 4c2376376dd2..c2e08d078fc1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -360,6 +360,49 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("PrefixSpan pre-processing's cleaning test") { + + // One item per itemSet + val itemToInt1 = (4 to 5).zipWithIndex.toMap + val sequences1 = Seq( + Array(Array(4), Array(1), Array(2), Array(5), Array(2), Array(4), Array(5)), + Array(Array(6), Array(7), Array(8))) + val rdd1 = sc.parallelize(sequences1, 2).cache() + + val cleanedSequence1 = PrefixSpan.toDatabaseInternalRepr(rdd1, itemToInt1).collect() + + val expected1 = Array(Array(0, 4, 0, 5, 0, 4, 0, 5, 0)) + .map(_.map(x => if (x == 0) 0 else itemToInt1(x) + 1)) + + compareInternalSequences(expected1, cleanedSequence1) + + // Multi-item sequence + val itemToInt2 = (4 to 6).zipWithIndex.toMap + val sequences2 = Seq( + Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)), + Array(Array(8, 9), Array(1, 2))) + val rdd2 = sc.parallelize(sequences2, 2).cache() + + val cleanedSequence2 = PrefixSpan.toDatabaseInternalRepr(rdd2, itemToInt2).collect() + + val expected2 = Array(Array(0, 4, 5, 0, 6, 0, 5, 0, 4, 0, 5, 6, 0)) + .map(_.map(x => if (x == 0) 0 else itemToInt2(x) + 1)) + + compareInternalSequences(expected2, cleanedSequence2) + + // Emptied sequence + val itemToInt3 = (10 to 10).zipWithIndex.toMap + val sequences3 = Seq( + Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)), + Array(Array(8, 9), Array(1, 2))) + val rdd3 = sc.parallelize(sequences3, 2).cache() + + val cleanedSequence3 = PrefixSpan.toDatabaseInternalRepr(rdd3, itemToInt3).collect() + val expected3 = Array[Array[Int]]() + + compareInternalSequences(expected3, cleanedSequence3) + } + test("model save/load") { val sequences = Seq( Array(Array(1, 2), Array(3)), @@ -409,4 +452,12 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet assert(expectedSet === actualSet) } + + private def compareInternalSequences( + expectedValue: Array[Array[Int]], + actualValue: Array[Array[Int]]): Unit = { + val expectedSet = expectedValue.map(x => x.toSeq).toSet + val actualSet = actualValue.map(x => x.toSeq).toSet + assert(expectedSet === actualSet) + } } From a4293c28438515d5ccf1f6b82f7b762e316d0a27 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 13 Apr 2017 09:56:34 +0100 Subject: [PATCH 279/512] [SPARK-20284][CORE] Make {Des,S}erializationStream extend Closeable ## What changes were proposed in this pull request? This PR allows to use `SerializationStream` and `DeserializationStream` in try-with-resources. ## How was this patch tested? `core` unit tests. Author: Sergei Lebedev Closes #17598 from superbobry/compression-stream-closeable. 
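A hedged usage sketch of the new `Closeable` contract: the loan-pattern helper below is written here for illustration and is not part of Spark's API, but it only depends on `java.io.Closeable`, which both stream types gain with this patch. The public `JavaSerializer` is used so the snippet stays self-contained.

```scala
import java.io.{ByteArrayOutputStream, Closeable}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object ClosingStreamsExample {
  // Loan-pattern helper written here for illustration; it only needs Closeable,
  // which SerializationStream and DeserializationStream now extend.
  def withCloseable[C <: Closeable, T](resource: C)(body: C => T): T = {
    try body(resource) finally resource.close()
  }

  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    val serializer = new JavaSerializer(new SparkConf()).newInstance()
    // The stream is closed even if writeObject throws.
    withCloseable(serializer.serializeStream(out)) { stream =>
      stream.writeObject("some record")
      stream.flush()
    }
  }
}
```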
--- .../scala/org/apache/spark/serializer/Serializer.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 01bbda0b5e6b..cb8b1cc07763 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -125,7 +125,7 @@ abstract class SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -abstract class SerializationStream { +abstract class SerializationStream extends Closeable { /** The most general-purpose method to write an object. */ def writeObject[T: ClassTag](t: T): SerializationStream /** Writes the object representing the key of a key-value pair. */ @@ -133,7 +133,7 @@ abstract class SerializationStream { /** Writes the object representing the value of a key-value pair. */ def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) def flush(): Unit - def close(): Unit + override def close(): Unit def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { while (iter.hasNext) { @@ -149,14 +149,14 @@ abstract class SerializationStream { * A stream for reading serialized objects. */ @DeveloperApi -abstract class DeserializationStream { +abstract class DeserializationStream extends Closeable { /** The most general-purpose method to read an object. */ def readObject[T: ClassTag](): T /** Reads the object representing the key of a key-value pair. */ def readKey[T: ClassTag](): T = readObject[T]() /** Reads the object representing the value of a key-value pair. */ def readValue[T: ClassTag](): T = readObject[T]() - def close(): Unit + override def close(): Unit /** * Read the elements of this stream through an iterator. This can only be called once, as From fbe4216e1e83d243a7f0521b76bfb20c25278281 Mon Sep 17 00:00:00 2001 From: Ioana Delaney Date: Thu, 13 Apr 2017 22:27:04 +0800 Subject: [PATCH 280/512] [SPARK-20233][SQL] Apply star-join filter heuristics to dynamic programming join enumeration ## What changes were proposed in this pull request? Implements star-join filter to reduce the search space for dynamic programming join enumeration. Consider the following join graph: ``` T1 D1 - T2 - T3 \ / F1 | D2 star-join: {F1, D1, D2} non-star: {T1, T2, T3} ``` The following join combinations will be generated: ``` level 0: (F1), (D1), (D2), (T1), (T2), (T3) level 1: {F1, D1}, {F1, D2}, {T2, T3} level 2: {F1, D1, D2} level 3: {F1, D1, D2, T1}, {F1, D1, D2, T2} level 4: {F1, D1, D2, T1, T2}, {F1, D1, D2, T2, T3 } level 6: {F1, D1, D2, T1, T2, T3} ``` ## How was this patch tested? New test suite ```StarJOinCostBasedReorderSuite.scala```. Author: Ioana Delaney Closes #17546 from ioana-delaney/starSchemaCBOv3. 
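As a stand-alone restatement of the star-join filter condition described above, the sketch below phrases the same check over plan-id sets and walks the example graph from the commit message. The object name and the concrete id mapping are invented for illustration.

```scala
object StarJoinFilterSketch {
  // Same condition as the new starJoinFilter, phrased over plan-id sets.
  def allowed(
      oneSide: Set[Int],
      otherSide: Set[Int],
      starJoins: Set[Int],
      nonStarJoins: Set[Int]): Boolean = {
    val join = oneSide.union(otherSide)
    oneSide.intersect(otherSide).isEmpty &&            // the two sides are disjoint
      (starJoins.isEmpty || nonStarJoins.isEmpty ||
        join.subsetOf(starJoins) ||                    // (1) still building the star join
        starJoins.subsetOf(join) ||                    // (2) star join already complete
        join.subsetOf(nonStarJoins))                   // (3) purely non-star tables
  }

  def main(args: Array[String]): Unit = {
    // Map the example graph to ids: F1 -> 0, D1 -> 1, D2 -> 2, T1 -> 3, T2 -> 4, T3 -> 5.
    val star = Set(0, 1, 2)
    val nonStar = Set(3, 4, 5)
    println(allowed(Set(0, 1), Set(3), star, nonStar))    // false: mixes star and non-star too early
    println(allowed(Set(0, 1, 2), Set(3), star, nonStar)) // true: the star join is already built
    println(allowed(Set(4), Set(5), star, nonStar))       // true: a pure non-star combination
  }
}
```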
--- .../optimizer/CostBasedJoinReorder.scala | 144 +++++- .../optimizer/StarSchemaDetection.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 8 + .../StarJoinCostBasedReorderSuite.scala | 426 ++++++++++++++++++ 4 files changed, 571 insertions(+), 9 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index cbd506465ae6..c704c2e6d36b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -54,8 +54,6 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = { val (items, conditions) = extractInnerJoins(plan) - // TODO: Compute the set of star-joins and use them in the join enumeration - // algorithm to prune un-optimal plan choices. val result = // Do reordering if the number of items is appropriate and join conditions exist. // We also need to check if costs of all items can be evaluated. @@ -150,12 +148,15 @@ object JoinReorderDP extends PredicateHelper with Logging { case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) }.toMap) + // Build filters from the join graph to be used by the search algorithm. + val filters = JoinReorderDPFilters.buildJoinGraphInfo(conf, items, conditions, itemIndex) + // Build plans for next levels until the last level has only one plan. This plan contains // all items that can be joined, so there's no need to continue. val topOutputSet = AttributeSet(output) - while (foundPlans.size < items.length && foundPlans.last.size > 1) { + while (foundPlans.size < items.length) { // Build plans for the next level. - foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet) + foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet, filters) } val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) @@ -179,7 +180,8 @@ object JoinReorderDP extends PredicateHelper with Logging { existingLevels: Seq[JoinPlanMap], conf: SQLConf, conditions: Set[Expression], - topOutput: AttributeSet): JoinPlanMap = { + topOutput: AttributeSet, + filters: Option[JoinGraphInfo]): JoinPlanMap = { val nextLevel = mutable.Map.empty[Set[Int], JoinPlan] var k = 0 @@ -200,7 +202,7 @@ object JoinReorderDP extends PredicateHelper with Logging { } otherSideCandidates.foreach { otherSidePlan => - buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) match { + buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput, filters) match { case Some(newJoinPlan) => // Check if it's the first plan for the item set, or it's a better plan than // the existing one due to lower cost. @@ -218,14 +220,20 @@ object JoinReorderDP extends PredicateHelper with Logging { } /** - * Builds a new JoinPlan when both conditions hold: + * Builds a new JoinPlan if the following conditions hold: * - the sets of items contained in left and right sides do not overlap. * - there exists at least one join condition involving references from both sides. 
+ * - if star-join filter is enabled, allow the following combinations: + * 1) (oneJoinPlan U otherJoinPlan) is a subset of star-join + * 2) star-join is a subset of (oneJoinPlan U otherJoinPlan) + * 3) (oneJoinPlan U otherJoinPlan) is a subset of non star-join + * * @param oneJoinPlan One side JoinPlan for building a new JoinPlan. * @param otherJoinPlan The other side JoinPlan for building a new join node. * @param conf SQLConf for statistics computation. * @param conditions The overall set of join conditions. * @param topOutput The output attributes of the final plan. + * @param filters Join graph info to be used as filters by the search algorithm. * @return Builds and returns a new JoinPlan if both conditions hold. Otherwise, returns None. */ private def buildJoin( @@ -233,13 +241,27 @@ object JoinReorderDP extends PredicateHelper with Logging { otherJoinPlan: JoinPlan, conf: SQLConf, conditions: Set[Expression], - topOutput: AttributeSet): Option[JoinPlan] = { + topOutput: AttributeSet, + filters: Option[JoinGraphInfo]): Option[JoinPlan] = { if (oneJoinPlan.itemIds.intersect(otherJoinPlan.itemIds).nonEmpty) { // Should not join two overlapping item sets. return None } + if (filters.isDefined) { + // Apply star-join filter, which ensures that tables in a star schema relationship + // are planned together. The star-filter will eliminate joins among star and non-star + // tables until the star joins are built. The following combinations are allowed: + // 1. (oneJoinPlan U otherJoinPlan) is a subset of star-join + // 2. star-join is a subset of (oneJoinPlan U otherJoinPlan) + // 3. (oneJoinPlan U otherJoinPlan) is a subset of non star-join + val isValidJoinCombination = + JoinReorderDPFilters.starJoinFilter(oneJoinPlan.itemIds, otherJoinPlan.itemIds, + filters.get) + if (!isValidJoinCombination) return None + } + val onePlan = oneJoinPlan.plan val otherPlan = otherJoinPlan.plan val joinConds = conditions @@ -327,3 +349,109 @@ object JoinReorderDP extends PredicateHelper with Logging { case class Cost(card: BigInt, size: BigInt) { def +(other: Cost): Cost = Cost(this.card + other.card, this.size + other.size) } + +/** + * Implements optional filters to reduce the search space for join enumeration. + * + * 1) Star-join filters: Plan star-joins together since they are assumed + * to have an optimal execution based on their RI relationship. + * 2) Cartesian products: Defer their planning later in the graph to avoid + * large intermediate results (expanding joins, in general). + * 3) Composite inners: Don't generate "bushy tree" plans to avoid materializing + * intermediate results. + * + * Filters (2) and (3) are not implemented. + */ +object JoinReorderDPFilters extends PredicateHelper { + /** + * Builds join graph information to be used by the filtering strategies. + * Currently, it builds the sets of star/non-star joins. + * It can be extended with the sets of connected/unconnected joins, which + * can be used to filter Cartesian products. + */ + def buildJoinGraphInfo( + conf: SQLConf, + items: Seq[LogicalPlan], + conditions: Set[Expression], + itemIndex: Seq[(LogicalPlan, Int)]): Option[JoinGraphInfo] = { + + if (conf.joinReorderDPStarFilter) { + // Compute the tables in a star-schema relationship. 
+ val starJoin = StarSchemaDetection(conf).findStarJoins(items, conditions.toSeq) + val nonStarJoin = items.filterNot(starJoin.contains(_)) + + if (starJoin.nonEmpty && nonStarJoin.nonEmpty) { + val itemMap = itemIndex.toMap + Some(JoinGraphInfo(starJoin.map(itemMap).toSet, nonStarJoin.map(itemMap).toSet)) + } else { + // Nothing interesting to return. + None + } + } else { + // Star schema filter is not enabled. + None + } + } + + /** + * Applies the star-join filter that eliminates join combinations among star + * and non-star tables until the star join is built. + * + * Given the oneSideJoinPlan/otherSideJoinPlan, which represent all the plan + * permutations generated by the DP join enumeration, and the star/non-star plans, + * the following plan combinations are allowed: + * 1. (oneSideJoinPlan U otherSideJoinPlan) is a subset of star-join + * 2. star-join is a subset of (oneSideJoinPlan U otherSideJoinPlan) + * 3. (oneSideJoinPlan U otherSideJoinPlan) is a subset of non star-join + * + * It assumes the sets are disjoint. + * + * Example query graph: + * + * t1 d1 - t2 - t3 + * \ / + * f1 + * | + * d2 + * + * star: {d1, f1, d2} + * non-star: {t2, t1, t3} + * + * level 0: (f1 ), (d2 ), (t3 ), (d1 ), (t1 ), (t2 ) + * level 1: {t3 t2 }, {f1 d2 }, {f1 d1 } + * level 2: {d2 f1 d1 } + * level 3: {t1 d1 f1 d2 }, {t2 d1 f1 d2 } + * level 4: {d1 t2 f1 t1 d2 }, {d1 t3 t2 f1 d2 } + * level 5: {d1 t3 t2 f1 t1 d2 } + * + * @param oneSideJoinPlan One side of the join represented as a set of plan ids. + * @param otherSideJoinPlan The other side of the join represented as a set of plan ids. + * @param filters Star and non-star plans represented as sets of plan ids + */ + def starJoinFilter( + oneSideJoinPlan: Set[Int], + otherSideJoinPlan: Set[Int], + filters: JoinGraphInfo) : Boolean = { + val starJoins = filters.starJoins + val nonStarJoins = filters.nonStarJoins + val join = oneSideJoinPlan.union(otherSideJoinPlan) + + // Disjoint sets + oneSideJoinPlan.intersect(otherSideJoinPlan).isEmpty && + // Either star or non-star is empty + (starJoins.isEmpty || nonStarJoins.isEmpty || + // Join is a subset of the star-join + join.subsetOf(starJoins) || + // Star-join is a subset of join + starJoins.subsetOf(join) || + // Join is a subset of non-star + join.subsetOf(nonStarJoins)) + } +} + +/** + * Helper class that keeps information about the join graph as sets of item/plan ids. + * It currently stores the star/non-star plans. It can be + * extended with the set of connected/unconnected plans. + */ +case class JoinGraphInfo (starJoins: Set[Int], nonStarJoins: Set[Int]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 91cb004eaec4..97ee9988386d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -76,7 +76,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { val emptyStarJoinPlan = Seq.empty[LogicalPlan] - if (!conf.starSchemaDetection || input.size < 2) { + if (input.size < 2) { emptyStarJoinPlan } else { // Find if the input plans are eligible for star join detection. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6b0f49503349..2e1798e22b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -736,6 +736,12 @@ object SQLConf { .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") .createWithDefault(0.7) + val JOIN_REORDER_DP_STAR_FILTER = + buildConf("spark.sql.cbo.joinReorder.dp.star.filter") + .doc("Applies star-join filter heuristics to cost based join enumeration.") + .booleanConf + .createWithDefault(false) + val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection") .doc("When true, it enables join reordering based on star schema detection. ") .booleanConf @@ -1011,6 +1017,8 @@ class SQLConf extends Serializable with Logging { def joinReorderCardWeight: Double = getConf(SQLConf.JOIN_REORDER_CARD_WEIGHT) + def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) def sortMergeJoinExecBufferSpillThreshold: Int = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala new file mode 100644 index 000000000000..a23d6266b284 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf._ + + +class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = new SQLConf().copy( + CBO_ENABLED -> true, + JOIN_REORDER_ENABLED -> true, + JOIN_REORDER_DP_STAR_FILTER -> true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Nil + } + + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + // F1 (fact table) + attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D1 (dimension) + attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D2 (dimension) + attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D3 (dimension) + attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + + // T1 (regular table i.e. 
outside star) + attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T2 (regular table) + attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T3 (regular table) + attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T4 (regular table) + attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T5 (regular table) + attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T6 (regular table) + attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4) + + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val f1 = StatsTestPlan( + outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c1", "f1_c2").map(nameToAttr), + rowCount = 1000, + size = Some(1000 * (8 + 4 * 5)), + attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c1", "f1_c2") + .map(nameToColInfo))) + + // To control the layout of the join plans, keep the size for the non-fact tables constant + // and vary the rowcount and the number of distinct values of the join columns. 
+ private val d1 = StatsTestPlan( + outputList = Seq("d1_pk", "d1_c2", "d1_c3").map(nameToAttr), + rowCount = 100, + size = Some(3000), + attributeStats = AttributeMap(Seq("d1_pk", "d1_c2", "d1_c3").map(nameToColInfo))) + + private val d2 = StatsTestPlan( + outputList = Seq("d2_pk", "d2_c2", "d2_c3").map(nameToAttr), + rowCount = 20, + size = Some(3000), + attributeStats = AttributeMap(Seq("d2_pk", "d2_c2", "d2_c3").map(nameToColInfo))) + + private val d3 = StatsTestPlan( + outputList = Seq("d3_pk", "d3_c2", "d3_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("d3_pk", "d3_c2", "d3_c3").map(nameToColInfo))) + + private val t1 = StatsTestPlan( + outputList = Seq("t1_c1", "t1_c2", "t1_c3").map(nameToAttr), + rowCount = 50, + size = Some(3000), + attributeStats = AttributeMap(Seq("t1_c1", "t1_c2", "t1_c3").map(nameToColInfo))) + + private val t2 = StatsTestPlan( + outputList = Seq("t2_c1", "t2_c2", "t2_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t2_c1", "t2_c2", "t2_c3").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3_c1", "t3_c2", "t3_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t3_c1", "t3_c2", "t3_c3").map(nameToColInfo))) + + private val t4 = StatsTestPlan( + outputList = Seq("t4_c1", "t4_c2", "t4_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t4_c1", "t4_c2", "t4_c3").map(nameToColInfo))) + + private val t5 = StatsTestPlan( + outputList = Seq("t5_c1", "t5_c2", "t5_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t5_c1", "t5_c2", "t5_c3").map(nameToColInfo))) + + private val t6 = StatsTestPlan( + outputList = Seq("t6_c1", "t6_c2", "t6_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t6_c1", "t6_c2", "t6_c3").map(nameToColInfo))) + + test("Test 1: Star query with two dimensions and two regular tables") { + + // d1 t1 + // \ / + // f1 + // / \ + // d2 t2 + // + // star: {f1, d1, d2} + // non-star: {t1, t2} + // + // level 0: (t2 ), (d2 ), (f1 ), (d1 ), (t1 ) + // level 1: {f1 d1 }, {d2 f1 } + // level 2: {d2 f1 d1 } + // level 3: {t2 d1 d2 f1 }, {t1 d1 d2 f1 } + // level 4: {f1 t1 t2 d1 d2 } + // + // Number of generated plans: 11 (vs. 
20 w/o filter) + val query = + f1.join(t1).join(t2).join(d1).join(d2) + .where((nameToAttr("f1_c1") === nameToAttr("t1_c1")) && + (nameToAttr("f1_c2") === nameToAttr("t2_c1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 2: Star with a linear branch") { + // + // t1 d1 - t2 - t3 + // \ / + // f1 + // | + // d2 + // + // star: {d1, f1, d2} + // non-star: {t2, t1, t3} + // + // level 0: (f1 ), (d2 ), (t3 ), (d1 ), (t1 ), (t2 ) + // level 1: {t3 t2 }, {f1 d2 }, {f1 d1 } + // level 2: {d2 f1 d1 } + // level 3: {t1 d1 f1 d2 }, {t2 d1 f1 d2 } + // level 4: {d1 t2 f1 t1 d2 }, {d1 t3 t2 f1 d2 } + // level 5: {d1 t3 t2 f1 t1 d2 } + // + // Number of generated plans: 15 (vs 24) + val query = + d1.join(t1).join(t2).join(f1).join(d2).join(t3) + .where((nameToAttr("d1_pk") === nameToAttr("f1_fk1")) && + (nameToAttr("t1_c1") === nameToAttr("f1_c1")) && + (nameToAttr("d2_pk") === nameToAttr("f1_fk2")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("d1_c2") === nameToAttr("t2_c1")) && + (nameToAttr("t2_c2") === nameToAttr("t3_c1"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner, + Some(nameToAttr("d1_c2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 3: Star with derived branches") { + // t3 t2 + // | | + // d1 - t4 - t1 + // | + // f1 + // | + // d2 + // + // star: (d1 f1 d2 ) + // non-star: (t4 t1 t2 t3 ) + // + // level 0: (t1 ), (t3 ), (f1 ), (d1 ), (t2 ), (d2 ), (t4 ) + // level 1: {f1 d2 }, {t1 t4 }, {t1 t2 }, {f1 d1 }, {t3 t4 } + // level 2: {d1 f1 d2 }, {t2 t1 t4 }, {t1 t3 t4 } + // level 3: {t4 d1 f1 d2 }, {t3 t4 t1 t2 } + // level 4: {d1 f1 t4 d2 t3 }, {d1 f1 t4 d2 t1 } + // level 5: {d1 f1 t4 d2 t1 t2 }, {d1 f1 t4 d2 t1 t3 } + // level 6: {d1 f1 t4 d2 t1 t2 t3 } + // + // Number of generated plans: 22 (vs. 
34) + val query = + d1.join(t1).join(t2).join(t3).join(t4).join(f1).join(d2) + .where((nameToAttr("t1_c1") === nameToAttr("t2_c1")) && + (nameToAttr("t3_c1") === nameToAttr("t4_c1")) && + (nameToAttr("t1_c2") === nameToAttr("t4_c2")) && + (nameToAttr("d1_c2") === nameToAttr("t4_c3")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t4, Inner, Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))), Inner, + Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))) + .join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner, + Some(nameToAttr("t1_c2") === nameToAttr("t4_c2"))) + + assertEqualPlans(query, expected) + } + + test("Test 4: Star with several branches") { + // + // d1 - t3 - t4 + // | + // f1 - d3 - t1 - t2 + // | + // d2 - t5 - t6 + // + // star: {d1 f1 d2 d3 } + // non-star: {t5 t3 t6 t2 t4 t1} + // + // level 0: (t4 ), (d2 ), (t5 ), (d3 ), (d1 ), (f1 ), (t2 ), (t6 ), (t1 ), (t3 ) + // level 1: {t5 t6 }, {t4 t3 }, {d3 f1 }, {t2 t1 }, {d2 f1 }, {d1 f1 } + // level 2: {d2 d1 f1 }, {d2 d3 f1 }, {d3 d1 f1 } + // level 3: {d2 d1 d3 f1 } + // level 4: {d1 t3 d3 f1 d2 }, {d1 d3 f1 t1 d2 }, {d1 t5 d3 f1 d2 } + // level 5: {d1 t5 d3 f1 t1 d2 }, {d1 t3 t4 d3 f1 d2 }, {d1 t5 t6 d3 f1 d2 }, + // {d1 t5 t3 d3 f1 d2 }, {d1 t3 d3 f1 t1 d2 }, {d1 t2 d3 f1 t1 d2 } + // level 6: {d1 t5 t3 t4 d3 f1 d2 }, {d1 t3 t2 d3 f1 t1 d2 }, {d1 t5 t6 d3 f1 t1 d2 }, + // {d1 t5 t3 d3 f1 t1 d2 }, {d1 t5 t2 d3 f1 t1 d2 }, ... + // ... + // level 9: {d1 t5 t3 t6 t2 t4 d3 f1 t1 d2 } + // + // Number of generated plans: 46 (vs. 82) + val query = + d1.join(t3).join(t4).join(f1).join(d2).join(t5).join(t6).join(d3).join(t1).join(t2) + .where((nameToAttr("d1_c2") === nameToAttr("t3_c1")) && + (nameToAttr("t3_c2") === nameToAttr("t4_c2")) && + (nameToAttr("d1_pk") === nameToAttr("f1_fk1")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("d2_c2") === nameToAttr("t5_c1")) && + (nameToAttr("t5_c2") === nameToAttr("t6_c2")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk")) && + (nameToAttr("d3_c2") === nameToAttr("t1_c1")) && + (nameToAttr("t1_c2") === nameToAttr("t2_c2"))) + + val expected = + f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(t4.join(t3, Inner, Some(nameToAttr("t3_c2") === nameToAttr("t4_c2"))), Inner, + Some(nameToAttr("d1_c2") === nameToAttr("t3_c1"))) + .join(t2.join(t1, Inner, Some(nameToAttr("t1_c2") === nameToAttr("t2_c2"))), Inner, + Some(nameToAttr("d3_c2") === nameToAttr("t1_c1"))) + .join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner, + Some(nameToAttr("d2_c2") === nameToAttr("t5_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 5: RI star only") { + // d1 + // | + // f1 + // / \ + // d2 d3 + // + // star: {f1, d1, d2, d3} + // non-star: {} + // level 0: (d1), (f1), (d2), (d3) + // level 1: {f1 d3 }, {f1 d2 }, {d1 f1 } + // level 2: {d1 f1 d2 }, {d2 f1 d3 }, {d1 f1 d3 } + // level 3: {d1 d2 f1 d3 } + // Number of generated plans: 11 (= 11) + val query = + d1.join(d2).join(f1).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("f1_fk3") 
=== nameToAttr("d3_pk"))) + + val expected = + f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + + assertEqualPlans(query, expected) + } + + test("Test 6: No RI star") { + // + // f1 - t1 - t2 - t3 + // + // star: {} + // non-star: {f1, t1, t2, t3} + // level 0: (t1), (f1), (t2), (t3) + // level 1: {f1 t3 }, {f1 t2 }, {t1 f1 } + // level 2: {t1 f1 t2 }, {t2 f1 t3 }, {dt f1 t3 } + // level 3: {t1 t2 f1 t3 } + // Number of generated plans: 11 (= 11) + val query = + t1.join(f1).join(t2).join(t3) + .where((nameToAttr("f1_fk1") === nameToAttr("t1_c1")) && + (nameToAttr("f1_fk2") === nameToAttr("t2_c1")) && + (nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) + + val expected = + f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) + .join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1"))) + + assertEqualPlans(query, expected) + } + + private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val optimized = Optimize.execute(plan1.analyze) + val expected = plan2.analyze + compareJoinOrder(optimized, expected) + } +} From 8ddf0d2a60795a2306f94df8eac6e265b1fe5230 Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Thu, 13 Apr 2017 12:43:28 -0700 Subject: [PATCH 281/512] [SPARK-20232][PYTHON] Improve combineByKey docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Improve combineByKey documentation: * Add note on memory allocation * Change example code to use different mergeValue and mergeCombiners ## How was this patch tested? Doctest. ## Legal This is my original work and I license the work to the project under the project’s open source license. Author: David Gingrich Closes #17545 from dgingrich/topic-spark-20232-combinebykey-docs. --- python/pyspark/rdd.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 291c1caaaed5..60141792d499 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1804,17 +1804,31 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, a one-element list) - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of a list) - - C{mergeCombiners}, to combine two C's into a single one. + - C{mergeCombiners}, to combine two C's into a single one (e.g., merges + the lists) + + To avoid memory allocation, both mergeValue and mergeCombiners are allowed to + modify and return their first argument instead of creating a new C. In addition, users can control the partitioning of the output RDD. .. note:: V and C can be different -- for example, one might group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> def add(a, b): return a + str(b) - >>> sorted(x.combineByKey(str, add, add).collect()) - [('a', '11'), ('b', '1')] + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 2)]) + >>> def to_list(a): + ... return [a] + ... + >>> def append(a, b): + ... a.append(b) + ... return a + ... + >>> def extend(a, b): + ... a.extend(b) + ... return a + ... 
+ >>> sorted(x.combineByKey(to_list, append, extend).collect()) + [('a', [1, 2]), ('b', [1])] """ if numPartitions is None: numPartitions = self._defaultReducePartitions() From 7536e2849df6d63587fbf16b4ecb5db06fed7125 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Thu, 13 Apr 2017 15:30:44 -0500 Subject: [PATCH 282/512] [SPARK-20038][SQL] FileFormatWriter.ExecuteWriteTask.releaseResources() implementations to be re-entrant ## What changes were proposed in this pull request? have the`FileFormatWriter.ExecuteWriteTask.releaseResources()` implementations set `currentWriter=null` in a finally clause. This guarantees that if the first call to `currentWriter()` throws an exception, the second releaseResources() call made during the task cancel process will not trigger a second attempt to close the stream. ## How was this patch tested? Tricky. I've been fixing the underlying cause when I saw the problem [HADOOP-14204](https://issues.apache.org/jira/browse/HADOOP-14204), but SPARK-10109 shows I'm not the first to have seen this. I can't replicate it locally any more, my code no longer being broken. code review, however, should be straightforward Author: Steve Loughran Closes #17364 from steveloughran/stevel/SPARK-20038-close. --- .../execution/datasources/FileFormatWriter.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index bda64d4b91bb..4ec09bff429c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -324,8 +324,11 @@ object FileFormatWriter extends Logging { override def releaseResources(): Unit = { if (currentWriter != null) { - currentWriter.close() - currentWriter = null + try { + currentWriter.close() + } finally { + currentWriter = null + } } } } @@ -459,8 +462,11 @@ object FileFormatWriter extends Logging { override def releaseResources(): Unit = { if (currentWriter != null) { - currentWriter.close() - currentWriter = null + try { + currentWriter.close() + } finally { + currentWriter = null + } } } } From fb036c4413c2cd4d90880d080f418ec468d6c0fc Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Fri, 14 Apr 2017 19:16:47 +0800 Subject: [PATCH 283/512] [SPARK-20318][SQL] Use Catalyst type for min/max in ColumnStat for ease of estimation ## What changes were proposed in this pull request? Currently when estimating predicates like col > literal or col = literal, we will update min or max in column stats based on literal value. However, literal value is of Catalyst type (internal type), while min/max is of external type. Then for the next predicate, we again need to do type conversion to compare and update column stats. This is awkward and causes many unnecessary conversions in estimation. To solve this, we use Catalyst type for min/max in `ColumnStat`. Note that the persistent format in metastore is still of external type, so there's no inconsistency for statistics in metastore. This pr also fixes a bug for boolean type in `IN` condition. ## How was this patch tested? The changes for ColumnStat are covered by existing tests. For bug fix, a new test for boolean type in IN condition is added Author: wangzhenhua Closes #17630 from wzhfy/refactorColumnStat. 
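For illustration, a minimal sketch (not part of this change) of the internal vs. external representations involved; `DateTimeUtils` and `Decimal` are the existing Catalyst helpers used for the conversion:

```scala
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.Decimal

// External (metastore- and Row-facing) values
val externalDate = java.sql.Date.valueOf("2017-01-01")
val externalDecimal = new java.math.BigDecimal("0.200000000000000000")

// Catalyst-internal values, which ColumnStat.min/max now store directly
val internalDate: Int = DateTimeUtils.fromJavaDate(externalDate)     // days since epoch
val internalDecimal: Decimal = Decimal(externalDecimal)

// Conversion back to the external form is now confined to (de)serialization,
// e.g. ColumnStat.toMap / ColumnStat.fromMap
val roundTripped: java.sql.Date = DateTimeUtils.toJavaDate(internalDate)
```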
--- .../catalyst/plans/logical/Statistics.scala | 95 +++++++++++++------ .../statsEstimation/EstimationUtils.scala | 30 +++++- .../statsEstimation/FilterEstimation.scala | 68 ++++--------- .../plans/logical/statsEstimation/Range.scala | 70 +++----------- .../FilterEstimationSuite.scala | 41 ++++---- .../statsEstimation/JoinEstimationSuite.scala | 15 +-- .../ProjectEstimationSuite.scala | 21 ++-- .../command/AnalyzeColumnCommand.scala | 8 +- .../spark/sql/StatisticsCollectionSuite.scala | 19 ++-- .../spark/sql/hive/HiveExternalCatalog.scala | 4 +- 10 files changed, 189 insertions(+), 182 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index f24b240956a6..3d4efef953a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -74,11 +75,10 @@ case class Statistics( * Statistics collected for a column. * * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the external data type (used in Row) for the - * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for - * TimestampType we store java.sql.Timestamp. - * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs. - * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * Catalyst data type. For example, the internal type of DateType is Int, and that the internal + * type of TimestampType is Long. + * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. * * @param distinctCount number of distinct values @@ -104,22 +104,43 @@ case class ColumnStat( /** * Returns a map from string to string that can be used to serialize the column stats. * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string - * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]]. + * representation for the value. min/max values are converted to the external data type. For + * example, for DateType we store java.sql.Date, and for TimestampType we store + * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. * * As part of the protocol, the returned map always contains a key called "version". * In the case min/max values are null (None), they won't appear in the map. 
*/ - def toMap: Map[String, String] = { + def toMap(colName: String, dataType: DataType): Map[String, String] = { val map = new scala.collection.mutable.HashMap[String, String] map.put(ColumnStat.KEY_VERSION, "1") map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) } + min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } + max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } map.toMap } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + private def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + } @@ -150,28 +171,15 @@ object ColumnStat extends Logging { * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. */ - def fromMap(table: String, field: StructField, map: Map[String, String]) - : Option[ColumnStat] = { - val str2val: (String => Any) = field.dataType match { - case _: IntegralType => _.toLong - case _: DecimalType => new java.math.BigDecimal(_) - case DoubleType | FloatType => _.toDouble - case BooleanType => _.toBoolean - case DateType => java.sql.Date.valueOf - case TimestampType => java.sql.Timestamp.valueOf - // This version of Spark does not use min/max for binary/string types so we ignore it. - case BinaryType | StringType => _ => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column ${field.name} of data type: ${field.dataType}.") - } - + def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { try { Some(ColumnStat( distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply), + min = map.get(KEY_MIN_VALUE) + .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), + max = map.get(KEY_MAX_VALUE) + .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), nullCount = BigInt(map(KEY_NULL_COUNT).toLong), avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong @@ -183,6 +191,30 @@ object ColumnStat extends Logging { } } + /** + * Converts from string representation of external data type to the corresponding Catalyst data + * type. 
+ */ + private def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + /** * Constructs an expression to compute column statistics for a given column. * @@ -232,11 +264,14 @@ object ColumnStat extends Logging { } /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ - def rowToColumnStat(row: Row): ColumnStat = { + def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = { ColumnStat( distinctCount = BigInt(row.getLong(0)), - min = Option(row.get(1)), // for string/binary min/max, get should return null - max = Option(row.get(2)), + // for string/binary min/max, get should return null + min = Option(row.get(1)) + .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), + max = Option(row.get(2)) + .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), nullCount = BigInt(row.getLong(3)), avgLen = row.getLong(4), maxLen = row.getLong(5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 5577233ffa6f..f1aff62cb6af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.{DecimalType, _} object EstimationUtils { @@ -75,4 +75,32 @@ object EstimationUtils { // (simple computation of statistics returns product of children). if (outputRowCount > 0) outputRowCount * sizePerRow else 1 } + + /** + * For simplicity we use Decimal to unify operations for data types whose min/max values can be + * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true). + * The two methods below are the contract of conversion. 
+ */ + def toDecimal(value: Any, dataType: DataType): Decimal = { + dataType match { + case _: NumericType | DateType | TimestampType => Decimal(value.toString) + case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0) + } + } + + def fromDecimal(dec: Decimal, dataType: DataType): Any = { + dataType match { + case BooleanType => dec.toLong == 1 + case DateType => dec.toInt + case TimestampType => dec.toLong + case ByteType => dec.toByte + case ShortType => dec.toShort + case IntegerType => dec.toInt + case LongType => dec.toLong + case FloatType => dec.toFloat + case DoubleType => dec.toDouble + case _: DecimalType => dec + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 7bd8e6511232..4b6b3b14d9ac 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -25,7 +25,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -301,30 +300,6 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging } } - /** - * For a SQL data type, its internal data type may be different from its external type. - * For DateType, its internal type is Int, and its external data type is Java Date type. - * The min/max values in ColumnStat are saved in their corresponding external type. - * - * @param attrDataType the column data type - * @param litValue the literal value - * @return a BigDecimal value - */ - def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { - attrDataType match { - case DateType => - Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) - case TimestampType => - Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) - case _: DecimalType => - Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal) - case StringType | BinaryType => - None - case _ => - Some(litValue) - } - } - /** * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. @@ -356,12 +331,16 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val statsRange = Range(colStat.min, colStat.max, attr.dataType) if (statsRange.contains(literal)) { if (update) { - // We update ColumnStat structure after apply this equality predicate. - // Set distinctCount to 1. Set nullCount to 0. - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attr.dataType, literal.value) - val newStats = colStat.copy(distinctCount = 1, min = newValue, - max = newValue, nullCount = 0) + // We update ColumnStat structure after apply this equality predicate: + // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal + // value. 
+ val newStats = attr.dataType match { + case StringType | BinaryType => + colStat.copy(distinctCount = 1, nullCount = 0) + case _ => + colStat.copy(distinctCount = 1, min = Some(literal.value), + max = Some(literal.value), nullCount = 0) + } colStatsMap(attr) = newStats } @@ -430,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging return Some(0.0) } - // Need to save new min/max using the external type value of the literal - val newMax = convertBoundValue( - attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString))) - val newMin = convertBoundValue( - attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString))) - + val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType)) + val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType)) // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), + max = Some(newMax), nullCount = 0) colStatsMap(attr) = newStats } @@ -478,8 +453,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val colStat = colStatsMap(attr) val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] - val max = BigDecimal(statsRange.max) - val min = BigDecimal(statsRange.min) + val max = statsRange.max.toBigDecimal + val min = statsRange.min.toBigDecimal val ndv = BigDecimal(colStat.distinctCount) // determine the overlapping degree between predicate range and column's range @@ -540,8 +515,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging } if (update) { - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attr.dataType, literal.value) + val newValue = Some(literal.value) var newMax = colStat.max var newMin = colStat.min var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() @@ -606,14 +580,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val colStatLeft = colStatsMap(attrLeft) val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) .asInstanceOf[NumericRange] - val maxLeft = BigDecimal(statsRangeLeft.max) - val minLeft = BigDecimal(statsRangeLeft.min) + val maxLeft = statsRangeLeft.max + val minLeft = statsRangeLeft.min val colStatRight = colStatsMap(attrRight) val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) .asInstanceOf[NumericRange] - val maxRight = BigDecimal(statsRangeRight.max) - val minRight = BigDecimal(statsRangeRight.min) + val maxRight = statsRangeRight.max + val minRight = statsRangeRight.min // determine the overlapping degree between predicate range and column's range val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 3d13967cb62a..4ac5ba5689f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -17,12 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import java.math.{BigDecimal => JDecimal} -import java.sql.{Date, Timestamp} - import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} +import org.apache.spark.sql.types._ /** Value range of a column. */ @@ -31,13 +27,10 @@ trait Range { } /** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: JDecimal, max: JDecimal) extends Range { +case class NumericRange(min: Decimal, max: Decimal) extends Range { override def contains(l: Literal): Boolean = { - val decimal = l.dataType match { - case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) - case _ => new JDecimal(l.value.toString) - } - min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0 + val lit = EstimationUtils.toDecimal(l.value, l.dataType) + min <= lit && max >= lit } } @@ -58,7 +51,10 @@ object Range { def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { case StringType | BinaryType => new DefaultRange() case _ if min.isEmpty || max.isEmpty => new NullRange() - case _ => toNumericRange(min.get, max.get, dataType) + case _ => + NumericRange( + min = EstimationUtils.toDecimal(min.get, dataType), + max = EstimationUtils.toDecimal(max.get, dataType)) } def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match { @@ -82,51 +78,11 @@ object Range { // binary/string types don't support intersecting. (None, None) case (n1: NumericRange, n2: NumericRange) => - val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max)) - val (newMin, newMax) = fromNumericRange(newRange, dt) - (Some(newMin), Some(newMax)) + // Choose the maximum of two min values, and the minimum of two max values. + val newMin = if (n1.min <= n2.min) n2.min else n1.min + val newMax = if (n1.max <= n2.max) n1.max else n2.max + (Some(EstimationUtils.fromDecimal(newMin, dt)), + Some(EstimationUtils.fromDecimal(newMax, dt))) } } - - /** - * For simplicity we use decimal to unify operations of numeric types, the two methods below - * are the contract of conversion. 
- */ - private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = { - dataType match { - case _: NumericType => - NumericRange(new JDecimal(min.toString), new JDecimal(max.toString)) - case BooleanType => - val min1 = if (min.asInstanceOf[Boolean]) 1 else 0 - val max1 = if (max.asInstanceOf[Boolean]) 1 else 0 - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case DateType => - val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date]) - val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case TimestampType => - val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp]) - val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - } - } - - private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = { - dataType match { - case _: IntegralType => - (n.min.longValue(), n.max.longValue()) - case FloatType | DoubleType => - (n.min.doubleValue(), n.max.doubleValue()) - case _: DecimalType => - (n.min, n.max) - case BooleanType => - (n.min.longValue() == 1, n.max.longValue() == 1) - case DateType => - (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue())) - case TimestampType => - (DateTimeUtils.toJavaTimestamp(n.min.longValue()), - DateTimeUtils.toJavaTimestamp(n.max.longValue())) - } - } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index cffb0d873928..a28447840ae0 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ /** @@ -45,15 +46,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { nullCount = 0, avgLen = 1, maxLen = 1) // column cdate has 10 values from 2017-01-01 through 2017-01-10. - val dMin = Date.valueOf("2017-01-01") - val dMax = Date.valueOf("2017-01-10") + val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) + val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
- val decMin = new java.math.BigDecimal("0.200000000000000000") - val decMax = new java.math.BigDecimal("0.800000000000000000") + val decMin = Decimal("0.200000000000000000") + val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) @@ -147,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3 OR null") { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) - val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> colStatInt), @@ -341,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 7) } + test("cbool IN (true)") { + validateEstimatedStats( + Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) + } + test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), @@ -358,9 +366,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdate = cast('2017-01-02' AS DATE)") { - val d20170102 = Date.valueOf("2017-01-02") + val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02")) validateEstimatedStats( - Filter(EqualTo(attrDate, Literal(d20170102)), + Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4)), @@ -368,9 +376,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdate < cast('2017-01-03' AS DATE)") { - val d20170103 = Date.valueOf("2017-01-03") + val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( - Filter(LessThan(attrDate, Literal(d20170103)), + Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), nullCount = 0, avgLen = 4, maxLen = 4)), @@ -379,19 +387,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("""cdate IN ( cast('2017-01-03' AS DATE), cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") { - val d20170103 = Date.valueOf("2017-01-03") - val d20170104 = Date.valueOf("2017-01-04") - val d20170105 = Date.valueOf("2017-01-05") + val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) + val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) + val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) validateEstimatedStats( - Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), - childStatsTestPlan(Seq(attrDate), 10L)), + Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), + Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 3) } test("cdecimal = 0.400000000000000000") { - val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") + 
val dec_0_40 = Decimal("0.400000000000000000") validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), @@ -401,7 +409,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdecimal < 0.60 ") { - val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") + val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), @@ -532,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint = cint3") { // no records qualify due to no overlap - val emptyColStats = Seq[(Attribute, ColumnStat)]() validateEstimatedStats( Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), Nil, // set to empty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index f62df842fa50..2d6b6e8e21f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{DateType, TimestampType, _} @@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("test join keys of different types") { /** Columns in a table with only one row */ def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = { - val dec = new java.math.BigDecimal("1.000000000000000000") - val date = Date.valueOf("2016-05-08") - val timestamp = Timestamp.valueOf("2016-05-08 00:00:01") + val dec = Decimal("1.000000000000000000") + val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) + val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1), + min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2), + min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), 
AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index f408dc415358..a5c4d22a2938 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("test row size estimation") { - val dec1 = new java.math.BigDecimal("1.000000000000000000") - val dec2 = new java.math.BigDecimal("8.000000000000000000") - val d1 = Date.valueOf("2016-05-08") - val d2 = Date.valueOf("2016-05-09") - val t1 = Timestamp.valueOf("2016-05-08 00:00:01") - val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + val dec1 = Decimal("1.000000000000000000") + val dec2 = Decimal("8.000000000000000000") + val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) + val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09")) + val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) + val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1), + min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2), + min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("cdecimal", 
DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index b89014ed8ef5..0d8db2ff5d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -73,10 +73,10 @@ case class AnalyzeColumnCommand( val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver - val attributesToAnalyze = AttributeSet(columnNames.map { col => + val attributesToAnalyze = columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) - }).toSeq + } // Make sure the column types are supported for stats gathering. attributesToAnalyze.foreach { attr => @@ -99,8 +99,8 @@ case class AnalyzeColumnCommand( val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() val rowCount = statsRow.getLong(0) - val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1))) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => + (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr)) }.toMap (rowCount, columnStats) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 1f547c5a2a8f..ddc393c8da05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -26,6 +26,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -117,7 +118,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) stats.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) assert(roundtrip == Some(v)) } } @@ -201,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. 
*/ protected val stats = mutable.LinkedHashMap( "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4), + "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16), + "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), "cstring" -> ColumnStat(2, None, None, 1, 3, 3), "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8) + "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), + Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), + Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) ) private val randomName = new Random(31) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 806f2be5faeb..8b0fdf49cefa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -526,8 +526,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (stats.rowCount.isDefined) { statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } + val colNameTypeMap: Map[String, DataType] = + tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap.foreach { case (k, v) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => statsProperties += (columnStatKeyPropName(colName, k) -> v) } } From 98b41ecbcbcddacdc2801c38fccc9823e710783b Mon Sep 17 00:00:00 2001 From: ouyangxiaochen Date: Sat, 15 Apr 2017 10:34:57 +0100 Subject: [PATCH 284/512] [SPARK-20316][SQL] Val and Var should strictly follow the Scala syntax ## What changes were proposed in this pull request? val and var should strictly follow the Scala syntax ## How was this patch tested? manual test and exisiting test cases Author: ouyangxiaochen Closes #17628 from ouyangxiaochen/spark-413. --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index d5cc3b385504..33e18a8da60f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -47,8 +47,8 @@ import org.apache.spark.util.ShutdownHookManager * has dropped its support. 
*/ private[hive] object SparkSQLCLIDriver extends Logging { - private var prompt = "spark-sql" - private var continuedPrompt = "".padTo(prompt.length, ' ') + private val prompt = "spark-sql" + private val continuedPrompt = "".padTo(prompt.length, ' ') private var transport: TSocket = _ installSignalHandler() From 35e5ae4f81176af52569c465520a703529893b50 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 16 Apr 2017 11:14:18 +0800 Subject: [PATCH 285/512] [SPARK-19716][SQL][FOLLOW-UP] UnresolvedMapObjects should always be serializable ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/17398 we introduced `UnresolvedMapObjects` as a placeholder of `MapObjects`. Unfortunately `UnresolvedMapObjects` is not serializable as its `function` may reference Scala `Type` which is not serializable. Ideally this is fine, as we will never serialize and send unresolved expressions to executors. However users may accidentally do this, e.g. mistakenly reference an encoder instance when implementing `Aggregator`, we should fix it so that it's just a performance issue(more network traffic) and should not fail the query. ## How was this patch tested? N/A Author: Wenchen Fan Closes #17639 from cloud-fan/minor. --- .../expressions/objects/objects.scala | 56 ++++++++++--------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index eed773d4cb36..f446c3e4a75f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -406,7 +406,7 @@ case class WrapOption(child: Expression, optType: DataType) } /** - * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed + * A placeholder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. */ case class LambdaVariable( @@ -421,6 +421,27 @@ case class LambdaVariable( } } +/** + * When constructing [[MapObjects]], the element type must be given, which may not be available + * before analysis. This class acts like a placeholder for [[MapObjects]], and will be replaced by + * [[MapObjects]] during analysis after the input data is resolved. + * Note that, ideally we should not serialize and send unresolved expressions to executors, but + * users may accidentally do this(e.g. mistakenly reference an encoder instance when implementing + * Aggregator). Here we mark `function` as transient because it may reference scala Type, which is + * not serializable. Then even users mistakenly reference unresolved expression and serialize it, + * it's just a performance issue(more network traffic), and will not fail. 
+ */ +case class UnresolvedMapObjects( + @transient function: Expression => Expression, + child: Expression, + customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { + override lazy val resolved = false + + override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { + throw new UnsupportedOperationException("not resolved") + } +} + object MapObjects { private val curId = new java.util.concurrent.atomic.AtomicInteger() @@ -442,20 +463,8 @@ object MapObjects { val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - val builderValue = s"MapObjects_builderValue$id" - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, - customCollectionCls, builderValue) - } -} - -case class UnresolvedMapObjects( - function: Expression => Expression, - child: Expression, - customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { - override lazy val resolved = false - - override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { - throw new UnsupportedOperationException("not resolved") + MapObjects( + loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } } @@ -482,8 +491,6 @@ case class UnresolvedMapObjects( * @param inputData An expression that when evaluated returns a collection object. * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) - * @param builderValue The name of the builder variable used to construct the resulting collection - * (used only when returning ObjectType) */ case class MapObjects private( loopValue: String, @@ -491,8 +498,7 @@ case class MapObjects private( loopVarDataType: DataType, lambdaFunction: Expression, inputData: Expression, - customCollectionCls: Option[Class[_]], - builderValue: String) extends Expression with NonSQLExpression { + customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -590,15 +596,15 @@ case class MapObjects private( customCollectionCls match { case Some(cls) => // collection - val collObjectName = s"${cls.getName}$$.MODULE$$" - val getBuilderVar = s"$collObjectName.newBuilder()" + val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" + val builder = ctx.freshName("collectionBuilder") ( s""" - ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; - $builderValue.sizeHint($dataLength); + ${classOf[Builder[_, _]].getName} $builder = $getBuilder; + $builder.sizeHint($dataLength); """, - genValue => s"$builderValue.$$plus$$eq($genValue);", - s"(${cls.getName}) $builderValue.result();" + genValue => s"$builder.$$plus$$eq($genValue);", + s"(${cls.getName}) $builder.result();" ) case None => // array From e090f3c0ceebdf341536a1c0c70c80afddf2ee2a Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sun, 16 Apr 2017 12:09:34 +0800 Subject: [PATCH 286/512] [SPARK-20335][SQL] Children expressions of Hive UDF impacts the determinism of Hive UDF ### What changes were proposed in this pull request? ```JAVA /** * Certain optimizations should not be applied if UDF is not deterministic. * Deterministic UDF returns same result each time it is invoked with a * particular input. This determinism just needs to hold within the context of * a query. 
* * return true if the UDF is deterministic */ boolean deterministic() default true; ``` Based on the definition of [UDFType](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/UDFType.java#L42-L50), when Hive UDF's children are non-deterministic, Hive UDF is also non-deterministic. ### How was this patch tested? Added test cases. Author: Xiao Li Closes #17635 from gatorsmile/udfDeterministic. --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 ++-- .../hive/execution/AggregationQuerySuite.scala | 13 +++++++++++++ .../sql/hive/execution/HiveUDAFSuite.scala | 18 +++++++++++++++++- .../sql/hive/execution/HiveUDFSuite.scala | 15 +++++++++++++++ 4 files changed, 47 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 51c814cf32a8..a83ad61b204a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -44,7 +44,7 @@ private[hive] case class HiveSimpleUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - override def deterministic: Boolean = isUDFDeterministic + override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def nullable: Boolean = true @@ -123,7 +123,7 @@ private[hive] case class HiveGenericUDF( override def nullable: Boolean = true - override def deterministic: Boolean = isUDFDeterministic + override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4a8086d7e540..84f915977bd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -509,6 +509,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, null, 110.0, null, null, 10.0) :: Nil) } + test("non-deterministic children expressions of UDAF") { + val e = intercept[AnalysisException] { + spark.sql( + """ + |SELECT mydoublesum(value + 1.5 * key + rand()) + |FROM agg1 + |GROUP BY key + """.stripMargin) + }.getMessage + assert(Seq("nondeterministic expression", + "should not appear in the arguments of an aggregate function").forall(e.contains)) + } + test("interpreted aggregate function") { checkAnswer( spark.sql( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index c9ef72ee112c..479ca1e8def5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import org.apache.hadoop.hive.ql.udf.UDAFPercentile import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax} import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode} import org.apache.hadoop.hive.ql.util.JavaDataModel @@ -26,7 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -84,6 +85,21 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { Row(1, Row(1, 1)) )) } + + test("non-deterministic children expressions of UDAF") { + withTempView("view1") { + spark.range(1).selectExpr("id as x", "id as y").createTempView("view1") + withUserDefinedFunction("testUDAFPercentile" -> true) { + // non-deterministic children of Hive UDAF + sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'") + val e1 = intercept[AnalysisException] { + sql("SELECT testUDAFPercentile(x, rand()) from view1 group by y") + }.getMessage + assert(Seq("nondeterministic expression", + "should not appear in the arguments of an aggregate function").forall(e1.contains)) + } + } + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index ef6883839d43..4bbf9259192e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.{LongWritable, Writable} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -387,6 +388,20 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { hiveContext.reset() } + test("non-deterministic children of UDF") { + withUserDefinedFunction("testStringStringUDF" -> true, "testGenericUDFHash" -> true) { + // HiveSimpleUDF + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") + val df1 = sql("SELECT testStringStringUDF(rand(), \"hello\")") + assert(!df1.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) + + // HiveGenericUDF + sql(s"CREATE TEMPORARY FUNCTION testGenericUDFHash AS '${classOf[GenericUDFHash].getName}'") + val df2 = sql("SELECT testGenericUDFHash(rand())") + assert(!df2.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) + } + } + test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") From a888fed3099e84c2cf45e9419f684a3658ada19d Mon Sep 17 00:00:00 2001 From: Ji Yan Date: Sun, 16 Apr 2017 14:34:12 +0100 Subject: [PATCH 287/512] [SPARK-19740][MESOS] Add support in Spark to pass arbitrary parameters into docker when running on mesos with docker containerizer ## What changes were proposed in this pull request? 
Allow passing in arbitrary parameters into docker when launching spark executors on mesos with docker containerizer tnachen ## How was this patch tested? Manually built and tested with passed in parameter Author: Ji Yan Closes #17109 from yanji84/ji/allow_set_docker_user. --- docs/running-on-mesos.md | 10 ++++ .../mesos/MesosSchedulerBackendUtil.scala | 36 +++++++++++-- .../MesosSchedulerBackendUtilSuite.scala | 53 +++++++++++++++++++ 3 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ef01cfe4b92c..314a806edf39 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -356,6 +356,16 @@ See the [configuration page](configuration.html) for information on Spark config By default Mesos agents will not pull images they already have cached. + + spark.mesos.executor.docker.parameters + (none) + + Set the list of custom parameters which will be passed into the docker run command when launching the Spark executor on Mesos using the docker containerizer. The format of this property is a comma-separated list of + key/value pairs. Example: + +
    key1=val1,key2=val2,key3=val3
    + + spark.mesos.executor.docker.volumes (none) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index a2adb228dc29..fbcbc55099ec 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Volume} +import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Parameter, Volume} import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} import org.apache.spark.{SparkConf, SparkException} @@ -99,6 +99,28 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .toList } + /** + * Parse a list of docker parameters, each of which + * takes the form key=value + */ + private def parseParamsSpec(params: String): List[Parameter] = { + // split with limit of 2 to avoid parsing error when '=' + // exists in the parameter value + params.split(",").map(_.split("=", 2)).flatMap { spec: Array[String] => + val param: Parameter.Builder = Parameter.newBuilder() + spec match { + case Array(key, value) => + Some(param.setKey(key).setValue(value)) + case spec => + logWarning(s"Unable to parse arbitary parameters: $params. " + + "Expected form: \"key=value(, ...)\"") + None + } + } + .map { _.build() } + .toList + } + def containerInfo(conf: SparkConf): ContainerInfo = { val containerType = if (conf.contains("spark.mesos.executor.docker.image") && conf.get("spark.mesos.containerizer", "docker") == "docker") { @@ -120,8 +142,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .map(parsePortMappingsSpec) .getOrElse(List.empty) + val params = conf + .getOption("spark.mesos.executor.docker.parameters") + .map(parseParamsSpec) + .getOrElse(List.empty) + if (containerType == ContainerInfo.Type.DOCKER) { - containerInfo.setDocker(dockerInfo(image, forcePullImage, portMaps)) + containerInfo + .setDocker(dockerInfo(image, forcePullImage, portMaps, params)) } else { containerInfo.setMesos(mesosInfo(image, forcePullImage)) } @@ -144,11 +172,13 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { private def dockerInfo( image: String, forcePullImage: Boolean, - portMaps: List[ContainerInfo.DockerInfo.PortMapping]): DockerInfo = { + portMaps: List[ContainerInfo.DockerInfo.PortMapping], + params: List[Parameter]): DockerInfo = { val dockerBuilder = ContainerInfo.DockerInfo.newBuilder() .setImage(image) .setForcePullImage(forcePullImage) portMaps.foreach(dockerBuilder.addPortMappings(_)) + params.foreach(dockerBuilder.addParameters(_)) dockerBuilder.build } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala new file mode 100644 index 000000000000..caf9d89fdd20 --- /dev/null +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.scalatest._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class MesosSchedulerBackendUtilSuite extends SparkFunSuite { + + test("ContainerInfo fails to parse invalid docker parameters") { + val conf = new SparkConf() + conf.set("spark.mesos.executor.docker.parameters", "a,b") + conf.set("spark.mesos.executor.docker.image", "test") + + val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val params = containerInfo.getDocker.getParametersList + + assert(params.size() == 0) + } + + test("ContainerInfo parses docker parameters") { + val conf = new SparkConf() + conf.set("spark.mesos.executor.docker.parameters", "a=1,b=2,c=3") + conf.set("spark.mesos.executor.docker.image", "test") + + val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val params = containerInfo.getDocker.getParametersList + assert(params.size() == 3) + assert(params.get(0).getKey == "a") + assert(params.get(0).getValue == "1") + assert(params.get(1).getKey == "b") + assert(params.get(1).getValue == "2") + assert(params.get(2).getKey == "c") + assert(params.get(2).getValue == "3") + } +} From ad935f526f57a9621c0a5ba082b85414c28282f4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 16 Apr 2017 14:36:42 +0100 Subject: [PATCH 288/512] [SPARK-20343][BUILD] Add avro dependency in core POM to resolve build failure in SBT Hadoop 2.6 master on Jenkins ## What changes were proposed in this pull request? This PR proposes to add ``` org.apache.avro avro ``` in core POM to see if it resolves the build failure as below: ``` [error] /home/jenkins/workspace/spark-master-test-sbt-hadoop-2.6/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala:123: value createDatumWriter is not a member of org.apache.avro.generic.GenericData [error] writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) [error] ``` https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/2770/consoleFull ## How was this patch tested? I tried many ways but I was unable to reproduce this in my local. Sean also tried the way I did but he was also unable to reproduce this. Please refer the comments in https://github.com/apache/spark/pull/17477#issuecomment-294094092 Author: hyukjinkwon Closes #17642 from HyukjinKwon/SPARK-20343. 
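As an aside, a minimal standalone Scala sketch may make the fix above more concrete. It is not part of the patch, and the object and method names are illustrative only; it simply mirrors the `GenericData.get.createDatumWriter` call that the quoted compile error points at, which resolves only when a sufficiently new `org.apache.avro:avro` artifact is on the compile classpath, hence the explicit dependency added to `core/pom.xml`.

```scala
import java.io.ByteArrayOutputStream

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.io.{DatumWriter, EncoderFactory}

// Illustrative only: serialize a record the same way GenericAvroSerializer does,
// going through GenericData.get.createDatumWriter (the member the failing SBT
// Hadoop 2.6 build could not resolve against an older Avro artifact).
object AvroWriterSketch {
  def serialize(record: GenericRecord, schema: Schema): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val encoder = EncoderFactory.get.binaryEncoder(out, null)
    val writer = GenericData.get.createDatumWriter(schema)
      .asInstanceOf[DatumWriter[GenericRecord]]
    writer.write(record, encoder)
    encoder.flush()
    out.toByteArray
  }
}
```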
--- core/pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index 97a463abbefd..24ce36deeb16 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -33,6 +33,10 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro + org.apache.avro avro-mapred From 86d251c58591278a7c88745a1049e7a41db11964 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 16 Apr 2017 11:27:27 -0700 Subject: [PATCH 289/512] [SPARK-20278][R] Disable 'multiple_dots_linter' lint rule that is against project's code style ## What changes were proposed in this pull request? Currently, multi-dot separated variables in R is not allowed. For example, ```diff setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, asJsonArray = FALSE, ...) { + function(x, schema, as.json.array = FALSE, ...) { if (asJsonArray) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", ``` produces an error as below: ``` R/functions.R:2462:31: style: Words within variable and function names should be separated by '_' rather than '.'. function(x, schema, as.json.array = FALSE, ...) { ^~~~~~~~~~~~~ ``` This seems against https://google.github.io/styleguide/Rguide.xml#identifiers which says > The preferred form for variable names is all lower case letters and words separated with dots This looks because lintr by default https://github.com/jimhester/lintr follows http://r-pkgs.had.co.nz/style.html as written in the README.md. Few cases seems not following Google's one as "a few tweaks". Per [SPARK-6813](https://issues.apache.org/jira/browse/SPARK-6813), we follow Google's R Style Guide with few exceptions https://google.github.io/styleguide/Rguide.xml. This is also merged into Spark's website - https://github.com/apache/spark-website/pull/43 Also, it looks we have no limit on function name. This rule also looks affecting to the name of functions as written in the README.md. > `multiple_dots_linter`: check that function and variable names are separated by _ rather than .. ## How was this patch tested? Manually tested `./dev/lint-r`with the manual change below in `R/functions.R`: ```diff setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, asJsonArray = FALSE, ...) { + function(x, schema, as.json.array = FALSE, ...) { if (asJsonArray) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", ``` **Before** ```R R/functions.R:2462:31: style: Words within variable and function names should be separated by '_' rather than '.'. function(x, schema, as.json.array = FALSE, ...) { ^~~~~~~~~~~~~ ``` **After** ``` lintr checks passed. ``` Author: hyukjinkwon Closes #17590 from HyukjinKwon/disable-dot-in-name. 
--- R/pkg/.lintr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 038236fc149e..ae50b28ec616 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") From 24f09b39c7b947e52fda952676d5114c2540e732 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 17 Apr 2017 09:04:24 -0700 Subject: [PATCH 290/512] [SPARK-19828][R][FOLLOWUP] Rename asJsonArray to as.json.array in from_json function in R ## What changes were proposed in this pull request? This was suggested to be `as.json.array` at the first place in the PR to SPARK-19828 but we could not do this as the lint check emits an error for multiple dots in the variable names. After SPARK-20278, now we are able to use `multiple.dots.in.names`. `asJsonArray` in `from_json` function is still able to be changed as 2.2 is not released yet. So, this PR proposes to rename `asJsonArray` to `as.json.array`. ## How was this patch tested? Jenkins tests, local tests with `./R/run-tests.sh` and manual `./dev/lint-r`. Existing tests should cover this. Author: hyukjinkwon Closes #17653 from HyukjinKwon/SPARK-19828-followup. --- R/pkg/R/functions.R | 8 ++++---- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 449476dec533..c311921fb33d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2438,12 +2438,12 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_json #' #' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema} or array of \code{structType} if \code{asJsonArray} is set to \code{TRUE}. +#' \code{schema} or array of \code{structType} if \code{as.json.array} is set to \code{TRUE}. #' If the string is unparseable, the Column will contains the value NA. #' #' @param x Column containing the JSON string. #' @param schema a structType object to use as the schema to use when parsing the JSON string. -#' @param asJsonArray indicating if input string is JSON array of objects or a single object. +#' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @param ... additional named properties to control how the json is parsed, accepts the same #' options as the JSON data source. #' @@ -2459,8 +2459,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #'} #' @note from_json since 2.2.0 setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, asJsonArray = FALSE, ...) { - if (asJsonArray) { + function(x, schema, as.json.array = FALSE, ...) 
{ + if (as.json.array) { jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", "createArrayType", schema$jobj) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 3fbb618ddfc3..6a6c9a809ab1 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1454,7 +1454,7 @@ test_that("column functions", { jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) schema <- structType(structField("name", "string")) - arr <- collect(select(df, alias(from_json(df$people, schema, asJsonArray = TRUE), "arrcol"))) + arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) expect_equal(ncol(arr), 1) expect_equal(nrow(arr), 1) expect_is(arr[[1]][[1]], "list") From 01ff0350a85b179715946c3bd4f003db7c5e3641 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 17 Apr 2017 09:50:20 -0700 Subject: [PATCH 291/512] [SPARK-20349][SQL] ListFunctions returns duplicate functions after using persistent functions ### What changes were proposed in this pull request? The session catalog caches some persistent functions in the `FunctionRegistry`, so there can be duplicates. Our Catalog API `listFunctions` does not handle it. It would be better if `SessionCatalog` API can de-duplciate the records, instead of doing it by each API caller. In `FunctionRegistry`, our functions are identified by the unquoted string. Thus, this PR is try to parse it using our parser interface and then de-duplicate the names. ### How was this patch tested? Added test cases. Author: Xiao Li Closes #17646 from gatorsmile/showFunctions. --- .../sql/catalyst/catalog/SessionCatalog.scala | 21 ++++++++++++++----- .../sql/execution/command/functions.scala | 4 +--- .../sql/hive/execution/HiveUDFSuite.scala | 17 +++++++++++++++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 1417bccf657c..3fbf83f3a38a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -22,6 +22,7 @@ import java.util.Locale import javax.annotation.concurrent.GuardedBy import scala.collection.mutable +import scala.util.{Failure, Success, Try} import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration @@ -1202,15 +1203,25 @@ class SessionCatalog( def listFunctions(db: String, pattern: String): Seq[(FunctionIdentifier, String)] = { val dbName = formatDatabaseName(db) requireDbExists(dbName) - val dbFunctions = externalCatalog.listFunctions(dbName, pattern) - .map { f => FunctionIdentifier(f, Some(dbName)) } - val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) - .map { f => FunctionIdentifier(f) } + val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f => + FunctionIdentifier(f, Some(dbName)) } + val loadedFunctions = + StringUtils.filterPattern(functionRegistry.listFunction(), pattern).map { f => + // In functionRegistry, function names are stored as an unquoted format. 
+ Try(parser.parseFunctionIdentifier(f)) match { + case Success(e) => e + case Failure(_) => + // The names of some built-in functions are not parsable by our parser, e.g., % + FunctionIdentifier(f) + } + } val functions = dbFunctions ++ loadedFunctions + // The session catalog caches some persistent functions in the FunctionRegistry + // so there can be duplicates. functions.map { case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM") case f => (f, "USER") - } + }.distinct } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index e0d002936957..545082324f0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -207,8 +207,6 @@ case class ShowFunctionsCommand( case (f, "USER") if showUserFunctions => f.unquotedString case (f, "SYSTEM") if showSystemFunctions => f.unquotedString } - // The session catalog caches some persistent functions in the FunctionRegistry - // so there can be duplicates. - functionNames.distinct.sorted.map(Row(_)) + functionNames.sorted.map(Row(_)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 4bbf9259192e..4446af2e75e0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -573,6 +573,23 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1)) } } + + test("Show persistent functions") { + val testData = spark.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + withTempView("inputTable") { + testData.createOrReplaceTempView("inputTable") + withUserDefinedFunction("testUDFToListInt" -> false) { + val numFunc = spark.catalog.listFunctions().count() + sql(s"CREATE FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") + assert(spark.catalog.listFunctions().count() == numFunc + 1) + checkAnswer( + sql("SELECT testUDFToListInt(s) FROM inputTable"), + Seq(Row(Seq(1, 2, 3)))) + assert(sql("show functions").count() == numFunc + 1) + assert(spark.catalog.listFunctions().count() == numFunc + 1) + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { From e5fee3e4f853f906f0b476bb04ee35a15f1ae650 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Mon, 17 Apr 2017 11:17:57 -0700 Subject: [PATCH 292/512] [SPARK-17647][SQL] Fix backslash escaping in 'LIKE' patterns. ## What changes were proposed in this pull request? This patch fixes a bug in the way LIKE patterns are translated to Java regexes. The bug causes any character following an escaped backslash to be escaped, i.e. there is double-escaping. A concrete example is the following pattern:`'%\\%'`. The expected Java regex that this pattern should correspond to (according to the behavior described below) is `'.*\\.*'`, however the current situation leads to `'.*\\%'` instead. --- Update: in light of the discussion that ensued, we should explicitly define the expected behaviour of LIKE expressions, especially in certain edge cases. With the help of gatorsmile, we put together a list of different RDBMS and their variations wrt to certain standard features. 
| RDBMS\Features | Wildcards | Default escape [1] | Case sensitivity | | --- | --- | --- | --- | | [MS SQL Server](https://msdn.microsoft.com/en-us/library/ms179859.aspx) | _, %, [], [^] | none | no | | [Oracle](https://docs.oracle.com/cd/B12037_01/server.101/b10759/conditions016.htm) | _, % | none | yes | | [DB2 z/OS](http://www.ibm.com/support/knowledgecenter/SSEPEK_11.0.0/sqlref/src/tpc/db2z_likepredicate.html) | _, % | none | yes | | [MySQL](http://dev.mysql.com/doc/refman/5.7/en/string-comparison-functions.html) | _, % | none | no | | [PostreSQL](https://www.postgresql.org/docs/9.0/static/functions-matching.html) | _, % | \ | yes | | [Hive](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF) | _, % | none | yes | | Current Spark | _, % | \ | yes | [1] Default escape character: most systems do not have a default escape character, instead the user can specify one by calling a like expression with an escape argument [A] LIKE [B] ESCAPE [C]. This syntax is currently not supported by Spark, however I would volunteer to implement this feature in a separate ticket. The specifications are often quite terse and certain scenarios are undocumented, so here is a list of scenarios that I am uncertain about and would appreciate any input. Specifically I am looking for feedback on whether or not Spark's current behavior should be changed. 1. [x] Ending a pattern with the escape sequence, e.g. `like 'a\'`. PostreSQL gives an error: 'LIKE pattern must not end with escape character', which I personally find logical. Currently, Spark allows "non-terminated" escapes and simply ignores them as part of the pattern. According to [DB2's documentation](http://www.ibm.com/support/knowledgecenter/SSEPGG_9.7.0/com.ibm.db2.luw.messages.sql.doc/doc/msql00130n.html), ending a pattern in an escape character is invalid. _Proposed new behaviour in Spark: throw AnalysisException_ 2. [x] Empty input, e.g. `'' like ''` Postgres and DB2 will match empty input only if the pattern is empty as well, any other combination of empty input will not match. Spark currently follows this rule. 3. [x] Escape before a non-special character, e.g. `'a' like '\a'`. Escaping a non-wildcard character is not really documented but PostgreSQL just treats it verbatim, which I also find the least surprising behavior. Spark does the same. According to [DB2's documentation](http://www.ibm.com/support/knowledgecenter/SSEPGG_9.7.0/com.ibm.db2.luw.messages.sql.doc/doc/msql00130n.html), it is invalid to follow an escape character with anything other than an escape character, an underscore or a percent sign. _Proposed new behaviour in Spark: throw AnalysisException_ The current specification is also described in the operator's source code in this patch. ## How was this patch tested? Extra case in regex unit tests. Author: Jakob Odersky This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #15398 from jodersky/SPARK-17647. 
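Before the implementation diff, here is a standalone Scala sketch of the escaping rules the discussion above settles on; the names are illustrative and are not Spark APIs. An underscore maps to `.`, a percent sign maps to `.*`, a backslash may escape only `_`, `%` or another backslash, any other escape (including a trailing one) is rejected, and every remaining character is quoted literally.

```scala
import java.util.regex.Pattern

// Illustrative sketch of the LIKE-to-Java-regex translation described above.
object LikeSketch {
  def toRegex(pattern: String): String = {
    val out = new StringBuilder("(?s)") // dotall, so '_' and '%' also match newlines
    val chars = pattern.iterator
    while (chars.hasNext) {
      chars.next() match {
        case '\\' if chars.hasNext =>
          chars.next() match {
            case c @ ('_' | '%' | '\\') => out ++= Pattern.quote(c.toString)
            case c => throw new IllegalArgumentException(s"the escape character cannot precede '$c'")
          }
        case '\\' => throw new IllegalArgumentException("pattern must not end with the escape character")
        case '_' => out += '.'
        case '%' => out ++= ".*"
        case c => out ++= Pattern.quote(c.toString)
      }
    }
    out.result()
  }

  def like(input: String, pattern: String): Boolean = input.matches(toRegex(pattern))

  // The motivating example: an escaped backslash no longer swallows the following wildcard.
  // like("""%SystemDrive%\Users\John""", """\%SystemDrive\%\\Users%""")  // true
}
```

The diff below applies the same rules inside `StringUtils.escapeLikeRegex` and documents them on the `Like` expression.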
--- .../expressions/regexpExpressions.scala | 25 ++- .../spark/sql/catalyst/util/StringUtils.scala | 50 +++--- .../expressions/RegexpExpressionsSuite.scala | 161 +++++++++++------- .../sql/catalyst/util/StringUtilsSuite.scala | 4 +- 4 files changed, 153 insertions(+), 87 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 49b779711308..a36da8e94b3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -69,7 +69,30 @@ abstract class StringRegexExpression extends BinaryExpression * Simple RegEx pattern matching function */ @ExpressionDescription( - usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.") + usage = "str _FUNC_ pattern - Returns true if str matches pattern, " + + "null if any arguments are null, false otherwise.", + extended = """ + Arguments: + str - a string expression + pattern - a string expression. The pattern is a string which is matched literally, with + exception to the following special symbols: + + _ matches any one character in the input (similar to . in posix regular expressions) + + % matches zero ore more characters in the input (similar to .* in posix regular + expressions) + + The escape character is '\'. If an escape character precedes a special symbol or another + escape character, the following character is matched literally. It is invalid to escape + any other character. + + Examples: + > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%' + true + + See also: + Use RLIKE to match with standard regular expressions. +""") case class Like(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index cde8bd5b9614..ca22ea24207e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,32 +19,44 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.{Pattern, PatternSyntaxException} +import org.apache.spark.sql.AnalysisException import org.apache.spark.unsafe.types.UTF8String object StringUtils { - // replace the _ with .{1} exactly match 1 time of any character - // replace the % with .*, match 0 or more times with any character - def escapeLikeRegex(v: String): String = { - if (!v.isEmpty) { - "(?s)" + (' ' +: v.init).zip(v).flatMap { - case (prev, '\\') => "" - case ('\\', c) => - c match { - case '_' => "_" - case '%' => "%" - case _ => Pattern.quote("\\" + c) - } - case (prev, c) => + /** + * Validate and convert SQL 'like' pattern to a Java regular expression. + * + * Underscores (_) are converted to '.' and percent signs (%) are converted to '.*', other + * characters are quoted literally. Escaping is done according to the rules specified in + * [[org.apache.spark.sql.catalyst.expressions.Like]] usage documentation. An invalid pattern will + * throw an [[AnalysisException]]. 
+ * + * @param pattern the SQL pattern to convert + * @return the equivalent Java regular expression of the pattern + */ + def escapeLikeRegex(pattern: String): String = { + val in = pattern.toIterator + val out = new StringBuilder() + + def fail(message: String) = throw new AnalysisException( + s"the pattern '$pattern' is invalid, $message") + + while (in.hasNext) { + in.next match { + case '\\' if in.hasNext => + val c = in.next c match { - case '_' => "." - case '%' => ".*" - case _ => Pattern.quote(Character.toString(c)) + case '_' | '%' | '\\' => out ++= Pattern.quote(Character.toString(c)) + case _ => fail(s"the escape character is not allowed to precede '$c'") } - }.mkString - } else { - v + case '\\' => fail("it is not allowed to end with the escape character") + case '_' => out ++= "." + case '%' => out ++= ".*" + case c => out ++= Pattern.quote(Character.toString(c)) + } } + "(?s)" + out.result() // (?s) enables dotall mode, causing "." to match new lines } private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 5299549e7b4d..1ce150e09198 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -18,16 +18,38 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{IntegerType, StringType} /** * Unit tests for regular expression (regexp) related SQL expressions. */ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("LIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType).like("a"), null) + /** + * Check if a given expression evaluates to an expected output, in case the input is + * a literal and in case the input is in the form of a row. 
+ * @tparam A type of input + * @param mkExpr the expression to test for a given input + * @param input value that will be used to create the expression, as literal and in the form + * of a row + * @param expected the expected output of the expression + * @param inputToExpression an implicit conversion from the input type to its corresponding + * sql expression + */ + def checkLiteralRow[A](mkExpr: Expression => Expression, input: A, expected: Any) + (implicit inputToExpression: A => Expression): Unit = { + checkEvaluation(mkExpr(input), expected) // check literal input + + val regex = 'a.string.at(0) + checkEvaluation(mkExpr(regex), expected, create_row(input)) // check row input + } + + test("LIKE Pattern") { + + // null handling + checkLiteralRow(Literal.create(null, StringType).like(_), "a", null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) checkEvaluation( @@ -39,45 +61,64 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) - checkEvaluation("abdef" like "abdef", true) - checkEvaluation("a_%b" like "a\\__b", true) - checkEvaluation("addb" like "a_%b", true) - checkEvaluation("addb" like "a\\__b", false) - checkEvaluation("addb" like "a%\\%b", false) - checkEvaluation("a_%b" like "a%\\%b", true) - checkEvaluation("addb" like "a%", true) - checkEvaluation("addb" like "**", false) - checkEvaluation("abc" like "a%", true) - checkEvaluation("abc" like "b%", false) - checkEvaluation("abc" like "bc%", false) - checkEvaluation("a\nb" like "a_b", true) - checkEvaluation("ab" like "a%b", true) - checkEvaluation("a\nb" like "a%b", true) - } + // simple patterns + checkLiteralRow("abdef" like _, "abdef", true) + checkLiteralRow("a_%b" like _, "a\\__b", true) + checkLiteralRow("addb" like _, "a_%b", true) + checkLiteralRow("addb" like _, "a\\__b", false) + checkLiteralRow("addb" like _, "a%\\%b", false) + checkLiteralRow("a_%b" like _, "a%\\%b", true) + checkLiteralRow("addb" like _, "a%", true) + checkLiteralRow("addb" like _, "**", false) + checkLiteralRow("abc" like _, "a%", true) + checkLiteralRow("abc" like _, "b%", false) + checkLiteralRow("abc" like _, "bc%", false) + checkLiteralRow("a\nb" like _, "a_b", true) + checkLiteralRow("ab" like _, "a%b", true) + checkLiteralRow("a\nb" like _, "a%b", true) + + // empty input + checkLiteralRow("" like _, "", true) + checkLiteralRow("a" like _, "", false) + checkLiteralRow("" like _, "a", false) + + // SI-17647 double-escaping backslash + checkLiteralRow("""\\\\""" like _, """%\\%""", true) + checkLiteralRow("""%%""" like _, """%%""", true) + checkLiteralRow("""\__""" like _, """\\\__""", true) + checkLiteralRow("""\\\__""" like _, """%\\%\%""", false) + checkLiteralRow("""_\\\%""" like _, """%\\""", false) + + // unicode + // scalastyle:off nonascii + checkLiteralRow("a\u20ACa" like _, "_\u20AC_", true) + checkLiteralRow("a€a" like _, "_€_", true) + checkLiteralRow("a€a" like _, "_\u20AC_", true) + checkLiteralRow("a\u20ACa" like _, "_€_", true) + // scalastyle:on nonascii + + // invalid escaping + val invalidEscape = intercept[AnalysisException] { + evaluate("""a""" like """\a""") + } + assert(invalidEscape.getMessage.contains("pattern")) + + val endEscape = intercept[AnalysisException] { + evaluate("""a""" like """a\""") + } + 
assert(endEscape.getMessage.contains("pattern")) + + // case + checkLiteralRow("A" like _, "a%", false) + checkLiteralRow("a" like _, "A%", false) + checkLiteralRow("AaA" like _, "_a_", true) - test("LIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, create_row(null)) - checkEvaluation("abdef" like regEx, true, create_row("abdef")) - checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) - checkEvaluation("addb" like regEx, true, create_row("a_%b")) - checkEvaluation("addb" like regEx, false, create_row("a\\__b")) - checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) - checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) - checkEvaluation("addb" like regEx, true, create_row("a%")) - checkEvaluation("addb" like regEx, false, create_row("**")) - checkEvaluation("abc" like regEx, true, create_row("a%")) - checkEvaluation("abc" like regEx, false, create_row("b%")) - checkEvaluation("abc" like regEx, false, create_row("bc%")) - checkEvaluation("a\nb" like regEx, true, create_row("a_b")) - checkEvaluation("ab" like regEx, true, create_row("a%b")) - checkEvaluation("a\nb" like regEx, true, create_row("a%b")) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) + // example + checkLiteralRow("""%SystemDrive%\Users\John""" like _, """\%SystemDrive\%\\Users%""", true) } - test("RLIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + test("RLIKE Regular Expression") { + checkLiteralRow(Literal.create(null, StringType) rlike _, "abdef", null) checkEvaluation("abdef" rlike Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) @@ -87,42 +128,32 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) - checkEvaluation("abdef" rlike "abdef", true) - checkEvaluation("abbbbc" rlike "a.*c", true) + checkLiteralRow("abdef" rlike _, "abdef", true) + checkLiteralRow("abbbbc" rlike _, "a.*c", true) - checkEvaluation("fofo" rlike "^fo", true) - checkEvaluation("fo\no" rlike "^fo\no$", true) - checkEvaluation("Bn" rlike "^Ba*n", true) - checkEvaluation("afofo" rlike "fo", true) - checkEvaluation("afofo" rlike "^fo", false) - checkEvaluation("Baan" rlike "^Ba?n", false) - checkEvaluation("axe" rlike "pi|apa", false) - checkEvaluation("pip" rlike "^(pi)*$", false) + checkLiteralRow("fofo" rlike _, "^fo", true) + checkLiteralRow("fo\no" rlike _, "^fo\no$", true) + checkLiteralRow("Bn" rlike _, "^Ba*n", true) + checkLiteralRow("afofo" rlike _, "fo", true) + checkLiteralRow("afofo" rlike _, "^fo", false) + checkLiteralRow("Baan" rlike _, "^Ba?n", false) + checkLiteralRow("axe" rlike _, "pi|apa", false) + checkLiteralRow("pip" rlike _, "^(pi)*$", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) + checkLiteralRow("abc" rlike _, "^ab", true) + checkLiteralRow("abc" rlike _, "^bc", false) + checkLiteralRow("abc" rlike _, "^ab", true) + checkLiteralRow("abc" rlike _, "^bc", false) intercept[java.util.regex.PatternSyntaxException] { evaluate("abbbbc" rlike "**") } - } - - test("RLIKE Non-literal Regular Expression") { - 
val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) - checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) - checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) - checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) - checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, create_row("**")) + val regex = 'a.string.at(0) + evaluate("abbbbc" rlike regex, create_row("**")) } } - test("RegexReplace") { val row1 = create_row("100-200", "(\\d+)", "num") val row2 = create_row("100-200", "(\\d+)", "###") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index 2ffc18a8d14f..78fee5135c3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -24,9 +24,9 @@ class StringUtilsSuite extends SparkFunSuite { test("escapeLikeRegex") { assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E") - assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E") + assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E\\Q_\\E.\\Qb\\E") assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E") - assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E") + assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*\\Q%\\E\\Qb\\E") assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*") assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") From 0075562dd2551a31c35ca26922d6bd73cdb78ea4 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Mon, 17 Apr 2017 17:56:33 -0700 Subject: [PATCH 293/512] Typo fix: distitrbuted -> distributed ## What changes were proposed in this pull request? Typo fix: distitrbuted -> distributed ## How was this patch tested? Existing tests Author: Andrew Ash Closes #17664 from ash211/patch-1. --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 424bbca12319..b817570c0abf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -577,7 +577,7 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => flist.foreach { file => val (_, localizedPath) = distribute(file, resType = resType) - // If addToClassPath, we ignore adding jar multiple times to distitrbuted cache. + // If addToClassPath, we ignore adding jar multiple times to distributed cache. if (addToClasspath) { if (localizedPath != null) { cachedSecondaryJarLinks += localizedPath From 33ea908af94152147e996a6dc8da41ada27d5af3 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 17 Apr 2017 17:58:10 -0700 Subject: [PATCH 294/512] [TEST][MINOR] Replace repartitionBy with distribute in CollapseRepartitionSuite ## What changes were proposed in this pull request? Replace non-existent `repartitionBy` with `distribute` in `CollapseRepartitionSuite`. ## How was this patch tested? 
local build and `catalyst/testOnly *CollapseRepartitionSuite` Author: Jacek Laskowski Closes #17657 from jaceklaskowski/CollapseRepartitionSuite. --- .../optimizer/CollapseRepartitionSuite.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala index 59d2dc46f00c..8cc8decd65de 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala @@ -106,8 +106,8 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized2, correctAnswer) } - test("repartitionBy above repartition") { - // Always respects the top repartitionBy amd removes useless repartition + test("distribute above repartition") { + // Always respects the top distribute and removes useless repartition val query1 = testRelation .repartition(10) .distribute('a)(20) @@ -123,8 +123,8 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized2, correctAnswer) } - test("repartitionBy above coalesce") { - // Always respects the top repartitionBy amd removes useless coalesce below repartition + test("distribute above coalesce") { + // Always respects the top distribute and removes useless coalesce below repartition val query1 = testRelation .coalesce(10) .distribute('a)(20) @@ -140,8 +140,8 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized2, correctAnswer) } - test("repartition above repartitionBy") { - // Always respects the top repartition amd removes useless distribute below repartition + test("repartition above distribute") { + // Always respects the top repartition and removes useless distribute below repartition val query1 = testRelation .distribute('a)(10) .repartition(20) @@ -155,11 +155,10 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized1, correctAnswer) comparePlans(optimized2, correctAnswer) - } - test("coalesce above repartitionBy") { - // Remove useless coalesce above repartition + test("coalesce above distribute") { + // Remove useless coalesce above distribute val query1 = testRelation .distribute('a)(10) .coalesce(20) @@ -180,8 +179,8 @@ class CollapseRepartitionSuite extends PlanTest { comparePlans(optimized2, correctAnswer2) } - test("collapse two adjacent repartitionBys into one") { - // Always respects the top repartitionBy + test("collapse two adjacent distributes into one") { + // Always respects the top distribute val query1 = testRelation .distribute('b)(10) .distribute('a)(20) From b0a1e93e93167b53058525a20a8b06f7df5f09a2 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 17 Apr 2017 23:55:40 -0700 Subject: [PATCH 295/512] [SPARK-17647][SQL][FOLLOWUP][MINOR] fix typo ## What changes were proposed in this pull request? fix typo ## How was this patch tested? manual Author: Felix Cheung Closes #17663 from felixcheung/likedoctypo. 
--- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index a36da8e94b3a..3fa84589e3c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -79,7 +79,7 @@ abstract class StringRegexExpression extends BinaryExpression _ matches any one character in the input (similar to . in posix regular expressions) - % matches zero ore more characters in the input (similar to .* in posix regular + % matches zero or more characters in the input (similar to .* in posix regular expressions) The escape character is '\'. If an escape character precedes a special symbol or another From 07fd94e0d05e827fae65d6e0e1cb89e28c8f2771 Mon Sep 17 00:00:00 2001 From: Robert Stupp Date: Tue, 18 Apr 2017 11:02:43 +0100 Subject: [PATCH 296/512] [SPARK-20344][SCHEDULER] Duplicate call in FairSchedulableBuilder.addTaskSetManager ## What changes were proposed in this pull request? Eliminate the duplicate call to `Pool.getSchedulableByName()` in `FairSchedulableBuilder.addTaskSetManager` ## How was this patch tested? ./dev/run-tests Author: Robert Stupp Closes #17647 from snazy/20344-dup-call-master. --- .../spark/scheduler/SchedulableBuilder.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 417103436144..5f3c280ec31e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -181,23 +181,23 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) } override def addTaskSetManager(manager: Schedulable, properties: Properties) { - var poolName = DEFAULT_POOL_NAME - var parentPool = rootPool.getSchedulableByName(poolName) - if (properties != null) { - poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) - parentPool = rootPool.getSchedulableByName(poolName) - if (parentPool == null) { - // we will create a new pool that user has configured in app - // instead of being defined in xml file - parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, - DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(parentPool) - logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " + - "configured. This can happen when the file that pools are read from isn't set, or " + - s"when that file doesn't contain $poolName. 
Created $poolName with default " + - s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " + - s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)") + val poolName = if (properties != null) { + properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) + } else { + DEFAULT_POOL_NAME } + var parentPool = rootPool.getSchedulableByName(poolName) + if (parentPool == null) { + // we will create a new pool that user has configured in app + // instead of being defined in xml file + parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, + DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(parentPool) + logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " + + "configured. This can happen when the file that pools are read from isn't set, or " + + s"when that file doesn't contain $poolName. Created $poolName with default " + + s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " + + s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)") } parentPool.addSchedulable(manager) logInfo("Added task set " + manager.name + " tasks to pool " + poolName) From d4f10cbbe1b9d13e43d80a50d204781e1c5c2da9 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 18 Apr 2017 11:05:00 +0100 Subject: [PATCH 297/512] [SPARK-20343][BUILD] Force Avro 1.7.7 in sbt build to resolve build failure in SBT Hadoop 2.6 master on Jenkins ## What changes were proposed in this pull request? This PR proposes to force Avro's version to 1.7.7 in core to resolve the build failure as below: ``` [error] /home/jenkins/workspace/spark-master-test-sbt-hadoop-2.6/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala:123: value createDatumWriter is not a member of org.apache.avro.generic.GenericData [error] writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) [error] ``` https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/2770/consoleFull Note that this is a hack and should be removed in the future. ## How was this patch tested? I only tested this actually overrides the dependency. I tried many ways but I was unable to reproduce this in my local. Sean also tried the way I did but he was also unable to reproduce this. Please refer the comments in https://github.com/apache/spark/pull/17477#issuecomment-294094092 Author: hyukjinkwon Closes #17651 from HyukjinKwon/SPARK-20343-sbt. --- pom.xml | 1 + project/SparkBuild.scala | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index c1174593c192..14370d92a908 100644 --- a/pom.xml +++ b/pom.xml @@ -142,6 +142,7 @@ 2.4.0 2.0.8 3.1.2 + 1.7.7 hadoop2 0.9.3 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e52baf51aed1..77dae289f775 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -318,8 +318,8 @@ object SparkBuild extends PomBuild { enable(MimaBuild.mimaSettings(sparkHome, x))(x) } - /* Generate and pick the spark build info from extra-resources */ - enable(Core.settings)(core) + /* Generate and pick the spark build info from extra-resources and override a dependency */ + enable(Core.settings ++ CoreDependencyOverrides.settings)(core) /* Unsafe settings */ enable(Unsafe.settings)(unsafe) @@ -443,6 +443,16 @@ object DockerIntegrationTests { ) } +/** + * Overrides to work around sbt's dependency resolution being different from Maven's in Unidoc. + * + * Note that, this is a hack that should be removed in the future. 
See SPARK-20343 + */ +object CoreDependencyOverrides { + lazy val settings = Seq( + dependencyOverrides += "org.apache.avro" % "avro" % "1.7.7") +} + /** * Overrides to work around sbt's dependency resolution being different from Maven's. */ From 321b4f03bc983c582a3c6259019c077cdfac9d26 Mon Sep 17 00:00:00 2001 From: wangzhenhua Date: Tue, 18 Apr 2017 20:12:21 +0800 Subject: [PATCH 298/512] [SPARK-20366][SQL] Fix recursive join reordering: inside joins are not reordered ## What changes were proposed in this pull request? If a plan has multi-level successive joins, e.g.: ``` Join / \ Union t5 / \ Join t4 / \ Join t3 / \ t1 t2 ``` Currently we fail to reorder the inside joins, i.e. t1, t2, t3. In join reorder, we use `OrderedJoin` to indicate a join has been ordered, such that when transforming down the plan, these joins don't need to be rerodered again. But there's a problem in the definition of `OrderedJoin`: The real join node is a parameter, but not a child. This breaks the transform procedure because `mapChildren` applies transform function on parameters which should be children. In this patch, we change `OrderedJoin` to a class having the same structure as a join node. ## How was this patch tested? Add a corresponding test case. Author: wangzhenhua Closes #17668 from wzhfy/recursiveReorder. --- .../optimizer/CostBasedJoinReorder.scala | 22 +++++---- .../catalyst/optimizer/JoinReorderSuite.scala | 49 +++++++++++++++++-- 2 files changed, 58 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index c704c2e6d36b..51eca6ca3376 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} -import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike} +import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType} import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -47,7 +47,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } // After reordering is finished, convert OrderedJoin back to Join result transformDown { - case oj: OrderedJoin => oj.join + case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond) } } } @@ -87,22 +87,24 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, _: InnerLike, Some(cond)) => + case j @ Join(left, right, jt: InnerLike, Some(cond)) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) - OrderedJoin(j.copy(left = replacedLeft, right = replacedRight)) + OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) => p.copy(child = replaceWithOrderedJoin(j)) case _ => plan } +} - /** This is a wrapper class for a join node that has been ordered. 
*/ - private case class OrderedJoin(join: Join) extends BinaryNode { - override def left: LogicalPlan = join.left - override def right: LogicalPlan = join.right - override def output: Seq[Attribute] = join.output - } +/** This is a mimic class for a join node that has been ordered. */ +case class OrderedJoin( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + override def output: Seq[Attribute] = left.output ++ right.output } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 1922eb30fdce..71db4e2e0ec4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -25,13 +25,12 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED, JOIN_REORDER_ENABLED} +import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED} class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { - override val conf = new SQLConf().copy( - CASE_SENSITIVE -> true, CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) + override val conf = new SQLConf().copy(CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -212,6 +211,50 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } + test("reorder recursively") { + // Original order: + // Join + // / \ + // Union t5 + // / \ + // Join t4 + // / \ + // Join t3 + // / \ + // t1 t2 + val bottomJoins = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + val originalPlan = bottomJoins + .union(t4.select(nameToAttr("t4.v-1-10"))) + .join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5"))) + + // Should be able to reorder the bottom part. + // Best order: + // Join + // / \ + // Union t5 + // / \ + // Join t4 + // / \ + // Join t2 + // / \ + // t1 t3 + val bestBottomPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + val bestPlan = bestBottomPlan + .union(t4.select(nameToAttr("t4.v-1-10"))) + .join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5"))) + + assertEqualPlans(originalPlan, bestPlan) + } + private def assertEqualPlans( originalPlan: LogicalPlan, groundTruthBestPlan: LogicalPlan): Unit = { From 1f81dda37cfc2049fabd6abd93ef3720d0aa03ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Tue, 18 Apr 2017 10:02:21 -0700 Subject: [PATCH 299/512] [SPARK-20354][CORE][REST-API] When I request access to the 'http: //ip:port/api/v1/applications' link, return 'sparkUser' is empty in REST API. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When I request access to the 'http: //ip:port/api/v1/applications' link, get the json. I need the 'sparkUser' field specific value, because my Spark big data management platform needs to filter through this field which user submits the application to facilitate my administration and query, but the current return of the json string is empty, causing me this Function can not be achieved, that is, I do not know who the specific application is submitted by this REST Api. **current return json:** [ { "id" : "app-20170417152053-0000", "name" : "KafkaWordCount", "attempts" : [ { "startTime" : "2017-04-17T07:20:51.395GMT", "endTime" : "1969-12-31T23:59:59.999GMT", "lastUpdated" : "2017-04-17T07:20:51.395GMT", "duration" : 0, **"sparkUser" : "",** "completed" : false, "endTimeEpoch" : -1, "startTimeEpoch" : 1492413651395, "lastUpdatedEpoch" : 1492413651395 } ] } ] **When I fix this question, return json:** [ { "id" : "app-20170417154201-0000", "name" : "KafkaWordCount", "attempts" : [ { "startTime" : "2017-04-17T07:41:57.335GMT", "endTime" : "1969-12-31T23:59:59.999GMT", "lastUpdated" : "2017-04-17T07:41:57.335GMT", "duration" : 0, **"sparkUser" : "mr",** "completed" : false, "startTimeEpoch" : 1492414917335, "endTimeEpoch" : -1, "lastUpdatedEpoch" : 1492414917335 } ] } ] ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Author: guoxiaolong Author: guoxiaolongzte Closes #17656 from guoxiaolongzte/SPARK-20354. --- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 7d31ac54a717..bf4cf79e9faa 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -117,7 +117,7 @@ private[spark] class SparkUI private ( endTime = new Date(-1), duration = 0, lastUpdated = new Date(startTime), - sparkUser = "", + sparkUser = getSparkUser, completed = false )) )) From f654b39a63d4f9b118733733c7ed2a1b58649e3d Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Tue, 18 Apr 2017 12:35:27 -0700 Subject: [PATCH 300/512] [SPARK-20360][PYTHON] reprs for interpreters ## What changes were proposed in this pull request? Establishes a very minimal `_repr_html_` for PySpark's `SparkContext`. ## How was this patch tested? nteract: ![screen shot 2017-04-17 at 3 41 29 pm](https://cloud.githubusercontent.com/assets/836375/25107701/d57090ba-2385-11e7-8147-74bc2c50a41b.png) Jupyter: ![screen shot 2017-04-17 at 3 53 19 pm](https://cloud.githubusercontent.com/assets/836375/25107725/05bf1fe8-2386-11e7-93e1-07a20c917dde.png) Hydrogen: ![screen shot 2017-04-17 at 3 49 55 pm](https://cloud.githubusercontent.com/assets/836375/25107664/a75e1ddc-2385-11e7-8477-258661833007.png) Author: Kyle Kelley Closes #17662 from rgbkrk/repr. 
--- python/pyspark/context.py | 26 ++++++++++++++++++++++++++ python/pyspark/sql/session.py | 11 +++++++++++ 2 files changed, 37 insertions(+) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2961cda553d6..3be07325f416 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -240,6 +240,32 @@ def signal_handler(signal, frame): if isinstance(threading.current_thread(), threading._MainThread): signal.signal(signal.SIGINT, signal_handler) + def __repr__(self): + return "<SparkContext master={master} appName={appName}>".format( + master=self.master, + appName=self.appName, + ) + + def _repr_html_(self): + return """ +
        <div>
+            <p><b>SparkContext</b></p>
+
+            <p><a href="{sc.uiWebUrl}">Spark UI</a></p>
+
+            <dl>
+              <dt>Version</dt>
+                <dd><code>v{sc.version}</code></dd>
+              <dt>Master</dt>
+                <dd><code>{sc.master}</code></dd>
+              <dt>AppName</dt>
+                <dd><code>{sc.appName}</code></dd>
+            </dl>
+        </div>
    + """.format( + sc=self + ) + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9f4772eec9f2..c1bf2bd76fb7 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -221,6 +221,17 @@ def __init__(self, sparkContext, jsparkSession=None): or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + def _repr_html_(self): + return """ +
            <div>
+                <p><b>SparkSession - {catalogImplementation}</b></p>
+                {sc_HTML}
+            </div>
    + """.format( + catalogImplementation=self.conf.get("spark.sql.catalogImplementation"), + sc_HTML=self.sparkContext._repr_html_() + ) + @since(2.0) def newSession(self): """ From 74aa0df8f7f132b62754e5159262e4a5b9b641ab Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 18 Apr 2017 16:10:40 -0700 Subject: [PATCH 301/512] [SPARK-20377][SS] Fix JavaStructuredSessionization example ## What changes were proposed in this pull request? Extra accessors in java bean class causes incorrect encoder generation, which corrupted the state when using timeouts. ## How was this patch tested? manually ran the example Author: Tathagata Das Closes #17676 from tdas/SPARK-20377. --- .../sql/streaming/JavaStructuredSessionization.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java index da3a5dfe8628..d3c8516882fa 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -76,8 +76,6 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio for (String word : lineWithTimestamp.getLine().split(" ")) { eventList.add(new Event(word, lineWithTimestamp.getTimestamp())); } - System.out.println( - "Number of events from " + lineWithTimestamp.getLine() + " = " + eventList.size()); return eventList.iterator(); } }; @@ -100,7 +98,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio // If timed out, then remove session and send final update if (state.hasTimedOut()) { SessionUpdate finalUpdate = new SessionUpdate( - sessionId, state.get().getDurationMs(), state.get().getNumEvents(), true); + sessionId, state.get().calculateDuration(), state.get().getNumEvents(), true); state.remove(); return finalUpdate; @@ -133,7 +131,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio // Set timeout such that the session will be expired if no data received for 10 seconds state.setTimeoutDuration("10 seconds"); return new SessionUpdate( - sessionId, state.get().getDurationMs(), state.get().getNumEvents(), false); + sessionId, state.get().calculateDuration(), state.get().getNumEvents(), false); } } }; @@ -215,7 +213,8 @@ public void setStartTimestampMs(long startTimestampMs) { public long getEndTimestampMs() { return endTimestampMs; } public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; } - public long getDurationMs() { return endTimestampMs - startTimestampMs; } + public long calculateDuration() { return endTimestampMs - startTimestampMs; } + @Override public String toString() { return "SessionInfo(numEvents = " + numEvents + ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")"; From e468a96c404eb54261ab219734f67dc2f5b06dc0 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 19 Apr 2017 10:58:05 +0800 Subject: [PATCH 302/512] [SPARK-20254][SQL] Remove unnecessary data conversion for Dataset with primitive array ## What changes were proposed in this pull request? This PR elminates unnecessary data conversion, which is introduced by SPARK-19716, for Dataset with primitve array in the generated Java code. When we run the following example program, now we get the Java code "Without this PR". 
In this code, lines 56-82 are unnecessary since the primitive array in ArrayData can be converted into Java primitive array by using ``toDoubleArray()`` method. ``GenericArrayData`` is not required. ```java val ds = sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache ds.count ds.map(e => e).show ``` Without this PR ``` == Parsed Logical Plan == 'SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#25] +- 'MapElements , class [D, [StructField(value,ArrayType(DoubleType,false),true)], obj#24: [D +- 'DeserializeToObject unresolveddeserializer(unresolvedmapobjects(, getcolumnbyordinal(0, ArrayType(DoubleType,false)), None).toDoubleArray), obj#23: [D +- SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#2] +- ExternalRDD [obj#1] == Analyzed Logical Plan == value: array SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#25] +- MapElements , class [D, [StructField(value,ArrayType(DoubleType,false),true)], obj#24: [D +- DeserializeToObject mapobjects(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, assertnotnull(lambdavariable(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, true), - array element class: "scala.Double", - root class: "scala.Array"), value#2, None, MapObjects_builderValue5).toDoubleArray, obj#23: [D +- SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#2] +- ExternalRDD [obj#1] == Optimized Logical Plan == SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#25] +- MapElements , class [D, [StructField(value,ArrayType(DoubleType,false),true)], obj#24: [D +- DeserializeToObject mapobjects(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, assertnotnull(lambdavariable(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, true), - array element class: "scala.Double", - root class: "scala.Array"), value#2, None, MapObjects_builderValue5).toDoubleArray, obj#23: [D +- InMemoryRelation [value#2], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- *SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#2] +- Scan ExternalRDDScan[obj#1] == Physical Plan == *SerializeFromObject [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#25] +- *MapElements , obj#24: [D +- *DeserializeToObject mapobjects(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, assertnotnull(lambdavariable(MapObjects_loopValue5, MapObjects_loopIsNull5, DoubleType, true), - array element class: "scala.Double", - root class: "scala.Array"), value#2, None, MapObjects_builderValue5).toDoubleArray, obj#23: [D +- InMemoryTableScan [value#2] +- InMemoryRelation [value#2], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- *SerializeFromObject [staticinvoke(class 
org.apache.spark.sql.catalyst.expressions.UnsafeArrayData, ArrayType(DoubleType,false), fromPrimitiveArray, input[0, [D, true], true) AS value#2] +- Scan ExternalRDDScan[obj#1] ``` ```java /* 050 */ protected void processNext() throws java.io.IOException { /* 051 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 052 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 053 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 054 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 055 */ /* 056 */ ArrayData deserializetoobject_value1 = null; /* 057 */ /* 058 */ if (!inputadapter_isNull) { /* 059 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 060 */ /* 061 */ Double[] deserializetoobject_convertedArray = null; /* 062 */ deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength]; /* 063 */ /* 064 */ int deserializetoobject_loopIndex = 0; /* 065 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 066 */ MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex)); /* 067 */ MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 068 */ /* 069 */ if (MapObjects_loopIsNull2) { /* 070 */ throw new RuntimeException(((java.lang.String) references[0])); /* 071 */ } /* 072 */ if (false) { /* 073 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 074 */ } else { /* 075 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2; /* 076 */ } /* 077 */ /* 078 */ deserializetoobject_loopIndex += 1; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/ /* 082 */ } /* 083 */ boolean deserializetoobject_isNull = true; /* 084 */ double[] deserializetoobject_value = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull = false; /* 087 */ if (!deserializetoobject_isNull) { /* 088 */ Object deserializetoobject_funcResult = null; /* 089 */ deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray(); /* 090 */ if (deserializetoobject_funcResult == null) { /* 091 */ deserializetoobject_isNull = true; /* 092 */ } else { /* 093 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 098 */ } /* 099 */ /* 100 */ boolean mapelements_isNull = true; /* 101 */ double[] mapelements_value = null; /* 102 */ if (!false) { /* 103 */ mapelements_resultIsNull = false; /* 104 */ /* 105 */ if (!mapelements_resultIsNull) { /* 106 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 107 */ mapelements_argValue = deserializetoobject_value; /* 108 */ } /* 109 */ /* 110 */ mapelements_isNull = mapelements_resultIsNull; /* 111 */ if (!mapelements_isNull) { /* 112 */ Object mapelements_funcResult = null; /* 113 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 114 */ if (mapelements_funcResult == null) { /* 115 */ mapelements_isNull = true; /* 116 */ } else { /* 117 */ mapelements_value = (double[]) mapelements_funcResult; /* 118 */ } /* 119 */ /* 120 */ } /* 121 */ mapelements_isNull = mapelements_value == null; /* 122 */ } /* 123 */ /* 124 */ serializefromobject_resultIsNull = false; /* 125 */ /* 126 */ if 
(!serializefromobject_resultIsNull) { /* 127 */ serializefromobject_resultIsNull = mapelements_isNull; /* 128 */ serializefromobject_argValue = mapelements_value; /* 129 */ } /* 130 */ /* 131 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 132 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 133 */ serializefromobject_isNull = serializefromobject_value == null; /* 134 */ serializefromobject_holder.reset(); /* 135 */ /* 136 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 137 */ /* 138 */ if (serializefromobject_isNull) { /* 139 */ serializefromobject_rowWriter.setNullAt(0); /* 140 */ } else { /* 141 */ // Remember the current cursor so that we can calculate how many bytes are /* 142 */ // written later. /* 143 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 144 */ /* 145 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 146 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 147 */ // grow the global buffer before writing data. /* 148 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 149 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 150 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 151 */ /* 152 */ } else { /* 153 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 154 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 155 */ /* 156 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 157 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 158 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 159 */ } else { /* 160 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 161 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 162 */ } /* 163 */ } /* 164 */ } /* 165 */ /* 166 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 167 */ } /* 168 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 169 */ append(serializefromobject_result); /* 170 */ if (shouldStop()) return; /* 171 */ } /* 172 */ } ``` With this PR (eliminated lines 56-62 in the above code) ```java /* 047 */ protected void processNext() throws java.io.IOException { /* 048 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 049 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 050 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 051 */ ArrayData inputadapter_value = inputadapter_isNull ? 
null : (inputadapter_row.getArray(0)); /* 052 */ /* 053 */ boolean deserializetoobject_isNull = true; /* 054 */ double[] deserializetoobject_value = null; /* 055 */ if (!inputadapter_isNull) { /* 056 */ deserializetoobject_isNull = false; /* 057 */ if (!deserializetoobject_isNull) { /* 058 */ Object deserializetoobject_funcResult = null; /* 059 */ deserializetoobject_funcResult = inputadapter_value.toDoubleArray(); /* 060 */ if (deserializetoobject_funcResult == null) { /* 061 */ deserializetoobject_isNull = true; /* 062 */ } else { /* 063 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 064 */ } /* 065 */ /* 066 */ } /* 067 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 068 */ } /* 069 */ /* 070 */ boolean mapelements_isNull = true; /* 071 */ double[] mapelements_value = null; /* 072 */ if (!false) { /* 073 */ mapelements_resultIsNull = false; /* 074 */ /* 075 */ if (!mapelements_resultIsNull) { /* 076 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 077 */ mapelements_argValue = deserializetoobject_value; /* 078 */ } /* 079 */ /* 080 */ mapelements_isNull = mapelements_resultIsNull; /* 081 */ if (!mapelements_isNull) { /* 082 */ Object mapelements_funcResult = null; /* 083 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 084 */ if (mapelements_funcResult == null) { /* 085 */ mapelements_isNull = true; /* 086 */ } else { /* 087 */ mapelements_value = (double[]) mapelements_funcResult; /* 088 */ } /* 089 */ /* 090 */ } /* 091 */ mapelements_isNull = mapelements_value == null; /* 092 */ } /* 093 */ /* 094 */ serializefromobject_resultIsNull = false; /* 095 */ /* 096 */ if (!serializefromobject_resultIsNull) { /* 097 */ serializefromobject_resultIsNull = mapelements_isNull; /* 098 */ serializefromobject_argValue = mapelements_value; /* 099 */ } /* 100 */ /* 101 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 102 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 103 */ serializefromobject_isNull = serializefromobject_value == null; /* 104 */ serializefromobject_holder.reset(); /* 105 */ /* 106 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 107 */ /* 108 */ if (serializefromobject_isNull) { /* 109 */ serializefromobject_rowWriter.setNullAt(0); /* 110 */ } else { /* 111 */ // Remember the current cursor so that we can calculate how many bytes are /* 112 */ // written later. /* 113 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 114 */ /* 115 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 116 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 117 */ // grow the global buffer before writing data. 
/* 118 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 119 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 120 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 121 */ /* 122 */ } else { /* 123 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 124 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 125 */ /* 126 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 127 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 128 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 129 */ } else { /* 130 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 131 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 132 */ } /* 133 */ } /* 134 */ } /* 135 */ /* 136 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 137 */ } /* 138 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 139 */ append(serializefromobject_result); /* 140 */ if (shouldStop()) return; /* 141 */ } /* 142 */ } ``` ## How was this patch tested? Add test suites into `DatasetPrimitiveSuite` Author: Kazuaki Ishizaki Closes #17568 from kiszk/SPARK-20254. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../expressions/objects/objects.scala | 5 +- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../sql/catalyst/optimizer/expressions.scala | 3 + .../sql/catalyst/optimizer/objects.scala | 13 ++++ .../optimizer/EliminateMapObjectsSuite.scala | 62 +++++++++++++++++++ 6 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9816b33ae8df..d9f36f7f874d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2230,8 +2230,8 @@ class Analyzer( val result = resolved transformDown { case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => inputData.dataType match { - case ArrayType(et, _) => - val expr = MapObjects(func, inputData, et, cls) transformUp { + case ArrayType(et, cn) => + val expr = MapObjects(func, inputData, et, cn, cls) transformUp { case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f446c3e4a75f..1a202ecf745c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -451,6 +451,8 @@ object MapObjects { * @param function The function applied on the collection elements. 
* @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. + * @param elementNullable When false, indicating elements in the collection are always + * non-null value. * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) */ @@ -458,11 +460,12 @@ object MapObjects { function: Expression => Expression, inputData: Expression, elementType: DataType, + elementNullable: Boolean = true, customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) MapObjects( loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d221b0611a89..dd768d18e858 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -119,7 +119,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: - Batch("Typed Filter Optimization", fixedPoint, + Batch("Object Expressions Optimization", fixedPoint, + EliminateMapObjects, CombineTypedFilters) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8445ee06bd89..ea2c5d241d8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -368,6 +369,8 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { case EqualNullSafe(Literal(null, _), r) => IsNull(r) case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + case AssertNotNull(c, _) if !c.nullable => c + // For Coalesce, remove null literals. 
case e @ Coalesce(children) => val newChildren = children.filterNot(isNullLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 257dbfac8c3e..8cdc6425bcad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -96,3 +97,15 @@ object CombineTypedFilters extends Rule[LogicalPlan] { } } } + +/** + * Removes MapObjects when the following conditions are satisfied + * 1. Mapobject(... lambdavariable(..., false) ...), which means types for input and output + * are primitive types with non-nullable + * 2. no custom collection class specified representation of data item. + */ +object EliminateMapObjects extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala new file mode 100644 index 000000000000..d4f37e2a5e87 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class EliminateMapObjectsSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = { + Batch("EliminateMapObjects", FixedPoint(50), + NullPropagation(conf), + SimplifyCasts, + EliminateMapObjects) :: Nil + } + } + + implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() + implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]() + + test("SPARK-20254: Remove unnecessary data conversion for primitive array") { + val intObjType = ObjectType(classOf[Array[Int]]) + val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) + val intQuery = intInput.deserialize[Array[Int]].analyze + val intOptimized = Optimize.execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) + + val doubleObjType = ObjectType(classOf[Array[Double]]) + val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) + val doubleQuery = doubleInput.deserialize[Array[Double]].analyze + val doubleOptimized = Optimize.execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), + AttributeReference("obj", doubleObjType, true)(), doubleInput) + comparePlans(doubleOptimized, doubleExpected) + } +} From 702d85af2df9433254af6fa029683aa19c52a276 Mon Sep 17 00:00:00 2001 From: zero323 Date: Tue, 18 Apr 2017 19:59:18 -0700 Subject: [PATCH 303/512] [SPARK-20208][R][DOCS] Document R fpGrowth support ## What changes were proposed in this pull request? Document fpGrowth in: - vignettes - programming guide - code example ## How was this patch tested? Manual tests. Author: zero323 Closes #17557 from zero323/SPARK-20208. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 37 +++++++++++++++++++- examples/src/main/r/ml/fpm.R | 50 ++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/r/ml/fpm.R diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index a6ff650c33fe..f81dbab10b1e 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -505,6 +505,10 @@ SparkR supports the following machine learning models and algorithms. 
* Alternating Least Squares (ALS) +#### Frequent Pattern Mining + +* FP-growth + #### Statistics * Kolmogorov-Smirnov Test @@ -707,7 +711,7 @@ summary(tweedieGLM1) ``` We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: ```{r} -tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 1.2, link.power = 0.0) summary(tweedieGLM2) ``` @@ -906,6 +910,37 @@ predicted <- predict(model, df) head(predicted) ``` +#### FP-growth + +`spark.fpGrowth` executes FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. `itemsCol` should be an array of values. + +```{r} +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "T,R,U", "T,S", "V,R", "R,U,T,V", "R,S", "V,S,U", "U,R", "S,T", "V,R", "V,U,S", + "T,V,U", "R,V", "T,S", "T,S", "S,T", "S,U", "T,R", "V,R", "S,V", "T,S,U" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, minSupport = 0.2, minConfidence = 0.5) +``` + +`spark.freqItemsets` method can be used to retrieve a `SparkDataFrame` with the frequent itemsets. + +```{r} +head(spark.freqItemsets(fpm)) +``` + +`spark.associationRules` returns a `SparkDataFrame` with the association rules. + +```{r} +head(spark.associationRules(fpm)) +``` + +We can make predictions based on the `antecedent`. + +```{r} +head(predict(fpm, df)) +``` + #### Kolmogorov-Smirnov Test `spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). diff --git a/examples/src/main/r/ml/fpm.R b/examples/src/main/r/ml/fpm.R new file mode 100644 index 000000000000..89c4564457d9 --- /dev/null +++ b/examples/src/main/r/ml/fpm.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/fpm.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-fpm-example") + +# $example on$ +# Load training data + +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "1,2,5", "1,2,3,5", "1,2" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, itemsCol="items", minSupport=0.5, minConfidence=0.6) + +# Extracting frequent itemsets + +spark.freqItemsets(fpm) + +# Extracting association rules + +spark.associationRules(fpm) + +# Predict uses association rules to and combines possible consequents + +predict(fpm, df) + +# $example off$ + +sparkR.session.stop() From 608bf30f0b9759fd0b9b9f33766295550996a9eb Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Wed, 19 Apr 2017 15:52:47 +0800 Subject: [PATCH 304/512] [SPARK-20359][SQL] Avoid unnecessary execution in EliminateOuterJoin optimization that can lead to NPE Avoid necessary execution that can lead to NPE in EliminateOuterJoin and add test in DataFrameSuite to confirm NPE is no longer thrown ## What changes were proposed in this pull request? Change leftHasNonNullPredicate and rightHasNonNullPredicate to lazy so they are only executed when needed. ## How was this patch tested? Added test in DataFrameSuite that failed before this fix and now succeeds. Note that a test in catalyst project would be better but i am unsure how to do this. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Koert Kuipers Closes #17660 from koertkuipers/feat-catch-npe-in-eliminate-outer-join. --- .../apache/spark/sql/catalyst/optimizer/joins.scala | 4 ++-- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index c3ab58744953..2fe303977442 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -134,8 +134,8 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) - val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) - val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + lazy val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + lazy val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) join.joinType match { case RightOuter if leftHasNonNullPredicate => Inner diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 52bd4e19f895..b4893b56a8a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1722,4 +1722,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { "Cannot have map type columns in DataFrame which calls set operations")) } } + + test("SPARK-20359: catalyst outer join optimization should not throw npe") { + val df1 = Seq("a", "b", "c").toDF("x") + .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" 
}.apply($"x")) + val df2 = Seq("a", "b").toDF("x1") + df1 + .join(df2, df1("x") === df2("x1"), "left_outer") + .filter($"x1".isNotNull || !$"y".isin("a!")) + .count + } } From 773754b6c1516c15b64846a00e491535cbcb1007 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 19 Apr 2017 16:01:28 +0800 Subject: [PATCH 305/512] [SPARK-20356][SQL] Pruned InMemoryTableScanExec should have correct output partitioning and ordering ## What changes were proposed in this pull request? The output of `InMemoryTableScanExec` can be pruned and mismatch with `InMemoryRelation` and its child plan's output. This causes wrong output partitioning and ordering. ## How was this patch tested? Jenkins tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh Closes #17679 from viirya/SPARK-20356. --- .../columnar/InMemoryTableScanExec.scala | 4 +++- .../columnar/InMemoryColumnarQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 214e8d309de1..7063b08f7c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -42,7 +42,9 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { - val attrMap = AttributeMap(relation.child.output.zip(output)) + // attributes can be pruned so using relation's output. + // E.g., relation.output is [id, item] but this scan's output can be [item] only. + val attrMap = AttributeMap(relation.child.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 1e6a6a8ba336..109b1d9db60d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -414,4 +414,19 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet)) } } + + test("SPARK-20356: pruned InMemoryTableScanExec should have correct ordering and partitioning") { + withSQLConf("spark.sql.shuffle.partitions" -> "200") { + val df1 = Seq(("a", 1), ("b", 1), ("c", 2)).toDF("item", "group") + val df2 = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("item", "id") + val df3 = df1.join(df2, Seq("item")).select($"id", $"group".as("item")).distinct() + + df3.unpersist() + val agg_without_cache = df3.groupBy($"item").count() + + df3.cache() + val agg_with_cache = df3.groupBy($"item").count() + checkAnswer(agg_without_cache, agg_with_cache) + } + } } From 35378766ad7d3c494425a8781efe9cb9349732b7 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 19 Apr 2017 12:18:54 +0100 Subject: [PATCH 306/512] [SPARK-20343][BUILD] Avoid Unidoc build only if Hadoop 2.6 is explicitly set in SBT build ## What changes were proposed in this pull request? 
This PR proposes two things as below: - Avoid Unidoc build only if Hadoop 2.6 is explicitly set in SBT build Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the documentation build fails on a specific machine & environment in Jenkins but it was unable to reproduce. So, this PR just checks an environment variable `AMPLAB_JENKINS_BUILD_PROFILE` that is set in Hadoop 2.6 SBT build against branches on Jenkins, and then disables Unidoc build. **Note that PR builder will still build it with Hadoop 2.6 & SBT.** ``` ======================================================================== Building Unidoc API Documentation ======================================================================== [info] Building Spark unidoc (w/Hive 1.2.1) using SBT with these arguments: -Phadoop-2.6 -Pmesos -Pkinesis-asl -Pyarn -Phive-thriftserver -Phive unidoc Using /usr/java/jdk1.8.0_60 as default JAVA_HOME. ... ``` I checked the environment variables from the logs (first bit) as below: - **spark-master-test-sbt-hadoop-2.6** (this one is being failed) - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 SPARK_BRANCH=master AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.6 <- I use this variable AMPLAB_JENKINS="true" ``` - spark-master-test-sbt-hadoop-2.7 - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 SPARK_BRANCH=master AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.7 AMPLAB_JENKINS="true" ``` - spark-master-test-maven-hadoop-2.6 - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 HADOOP_PROFILE=hadoop-2.6 HADOOP_VERSION= SPARK_BRANCH=master AMPLAB_JENKINS="true" ``` - spark-master-test-maven-hadoop-2.7 - https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.7/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 HADOOP_PROFILE=hadoop-2.7 HADOOP_VERSION= SPARK_BRANCH=master AMPLAB_JENKINS="true" ``` - PR builder - https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/75843/consoleFull ``` JENKINS_MASTER_HOSTNAME=amp-jenkins-master JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 ``` Assuming from other logs in branch-2.1 - SBT & Hadoop 2.6 against branch-2.1 https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.1-test-sbt-hadoop-2.6/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 SPARK_BRANCH=branch-2.1 AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.6 AMPLAB_JENKINS="true" ``` - Maven & Hadoop 2.6 against branch-2.1 https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.1-test-maven-hadoop-2.6/lastBuild/consoleFull ``` JAVA_HOME=/usr/java/jdk1.8.0_60 JAVA_7_HOME=/usr/java/jdk1.7.0_79 HADOOP_PROFILE=hadoop-2.6 HADOOP_VERSION= SPARK_BRANCH=branch-2.1 AMPLAB_JENKINS="true" ``` We have been using the same convention for those variables. 
These are actually being used in `run-tests.py` script - here https://github.com/apache/spark/blob/master/dev/run-tests.py#L519-L520 - Revert the previous try After https://github.com/apache/spark/pull/17651, it seems the build still fails on SBT Hadoop 2.6 master. I am unable to reproduce this - https://github.com/apache/spark/pull/17477#issuecomment-294094092 and the reviewer was too. So, this got merged as it looks the only way to verify this is to merge it currently (as no one seems able to reproduce this). ## How was this patch tested? I only checked `is_hadoop_version_2_6 = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6"` is working fine as expected as below: ```python >>> import collections >>> os = collections.namedtuple('os', 'environ')(environ={"AMPLAB_JENKINS_BUILD_PROFILE": "hadoop2.6"}) >>> print(not os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6") False >>> os = collections.namedtuple('os', 'environ')(environ={"AMPLAB_JENKINS_BUILD_PROFILE": "hadoop2.7"}) >>> print(not os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6") True >>> os = collections.namedtuple('os', 'environ')(environ={}) >>> print(not os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6") True ``` I tried many ways but I was unable to reproduce this in my local. Sean also tried the way I did but he was also unable to reproduce this. Please refer the comments in https://github.com/apache/spark/pull/17477#issuecomment-294094092 Author: hyukjinkwon Closes #17669 from HyukjinKwon/revert-SPARK-20343. --- dev/run-tests.py | 12 ++++++++++-- pom.xml | 1 - project/SparkBuild.scala | 14 ++------------ 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 450b68123e1f..818a0c9f4841 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -365,8 +365,16 @@ def build_spark_assembly_sbt(hadoop_version): print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) - # Make sure that Java and Scala API documentation can be generated - build_spark_unidoc_sbt(hadoop_version) + + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. + # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the + # documentation build fails on a specific machine & environment in Jenkins but it was unable + # to reproduce. Please see SPARK-20343. This is a band-aid fix that should be removed in + # the future. 
+ is_hadoop_version_2_6 = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6" + if not is_hadoop_version_2_6: + # Make sure that Java and Scala API documentation can be generated + build_spark_unidoc_sbt(hadoop_version) def build_apache_spark(build_tool, hadoop_version): diff --git a/pom.xml b/pom.xml index 14370d92a908..c1174593c192 100644 --- a/pom.xml +++ b/pom.xml @@ -142,7 +142,6 @@ 2.4.0 2.0.8 3.1.2 - 1.7.7 hadoop2 0.9.3 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 77dae289f775..e52baf51aed1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -318,8 +318,8 @@ object SparkBuild extends PomBuild { enable(MimaBuild.mimaSettings(sparkHome, x))(x) } - /* Generate and pick the spark build info from extra-resources and override a dependency */ - enable(Core.settings ++ CoreDependencyOverrides.settings)(core) + /* Generate and pick the spark build info from extra-resources */ + enable(Core.settings)(core) /* Unsafe settings */ enable(Unsafe.settings)(unsafe) @@ -443,16 +443,6 @@ object DockerIntegrationTests { ) } -/** - * Overrides to work around sbt's dependency resolution being different from Maven's in Unidoc. - * - * Note that, this is a hack that should be removed in the future. See SPARK-20343 - */ -object CoreDependencyOverrides { - lazy val settings = Seq( - dependencyOverrides += "org.apache.avro" % "avro" % "1.7.7") -} - /** * Overrides to work around sbt's dependency resolution being different from Maven's. */ From 71a8e9df12e547cb4716f954ecb762b358f862d5 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Wed, 19 Apr 2017 18:58:58 +0100 Subject: [PATCH 307/512] [SPARK-20036][DOC] Note incompatible dependencies on org.apache.kafka artifacts ## What changes were proposed in this pull request? Note that you shouldn't manually add dependencies on org.apache.kafka artifacts ## How was this patch tested? Doc only change, did jekyll build and looked at the page. Author: cody koeninger Closes #17675 from koeninger/SPARK-20036. --- docs/streaming-kafka-0-10-integration.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index e3837013168d..92c296a9e6bd 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -12,6 +12,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} +**Do not** manually add dependencies on `org.apache.kafka` artifacts (e.g. `kafka-clients`). The `spark-streaming-kafka-0-10` artifact has the appropriate transitive dependencies already, and different versions may be incompatible in hard to diagnose ways. + ### Creating a Direct Stream Note that the namespace for the import includes the version, org.apache.spark.streaming.kafka010 From 4fea7848c45d85ff3ad0863de5d1449d1fd1b4b0 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 19 Apr 2017 13:10:44 -0700 Subject: [PATCH 308/512] [SPARK-20397][SPARKR][SS] Fix flaky test: test_streaming.R.Terminated by error ## What changes were proposed in this pull request? Checking a source parameter is asynchronous. When the query is created, it's not guaranteed that source has been created. This PR just increases the timeout of awaitTermination to ensure the parsing error is thrown. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17687 from zsxwing/SPARK-20397. 
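For reference, a minimal Scala sketch of the same race the R test exercises through the JVM backend (the session setup and input path are illustrative, not taken from the patch): the invalid `maxFilesPerTrigger` value is only validated when the file source is actually created, which happens asynchronously after `start()`, so a 1 ms `awaitTermination` can time out before the expected failure is reported, while a few seconds gives it time to propagate.

```scala
// A minimal sketch, assuming a local session and an illustrative input path.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryException

object AwaitTerminationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val lines = spark.readStream
      .format("text")
      .option("maxFilesPerTrigger", "-1") // invalid on purpose, mirroring the R test
      .load("/tmp/streaming-input")       // hypothetical path
    val query = lines.writeStream
      .format("memory")
      .queryName("people")
      .outputMode("append")
      .start()
    try {
      // Returns false on timeout; throws if the query has already failed.
      query.awaitTermination(5 * 1000)
    } catch {
      case e: StreamingQueryException => println(s"expected failure: ${e.getMessage}")
    } finally {
      spark.stop()
    }
  }
}
```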
--- R/pkg/inst/tests/testthat/test_streaming.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 03b1bd3dc1f4..1f4054a84df5 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -131,7 +131,7 @@ test_that("Terminated by error", { expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), NA) - expect_error(awaitTermination(q, 1), + expect_error(awaitTermination(q, 5 * 1000), paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", " 'maxFilesPerTrigger', must be a positive integer).*")) From 63824b2c8e010ba03013be498def236c654d4fed Mon Sep 17 00:00:00 2001 From: ptkool Date: Thu, 20 Apr 2017 09:51:13 +0800 Subject: [PATCH 309/512] [SPARK-20350] Add optimization rules to apply Complementation Laws. ## What changes were proposed in this pull request? Apply Complementation Laws during boolean expression simplification. ## How was this patch tested? Tested using unit tests, integration tests, and manual tests. Author: ptkool Author: Michael Styles Closes #17650 from ptkool/apply_complementation_laws. --- .../sql/catalyst/optimizer/expressions.scala | 5 +++++ .../BooleanSimplificationSuite.scala | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index ea2c5d241d8d..34382bd27240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -154,6 +154,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case TrueLiteral Or _ => TrueLiteral case _ Or TrueLiteral => TrueLiteral + case a And b if Not(a).semanticEquals(b) => FalseLiteral + case a Or b if Not(a).semanticEquals(b) => TrueLiteral + case a And b if a.semanticEquals(Not(b)) => FalseLiteral + case a Or b if a.semanticEquals(Not(b)) => TrueLiteral + case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 935bff7cef2e..c275f997ba6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.Row class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -42,6 +43,16 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + val testRelationWithData = LocalRelation.fromExternalRows( + testRelation.output, Seq(Row(1, 2, 3, "abc")) + ) + + private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { + val plan = testRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + 
comparePlans(actual, expected) + } + private def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze val actual = Optimize.execute(plan) @@ -160,4 +171,12 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { testRelation.where('a > 2 || ('b > 3 && 'b < 5))) comparePlans(actual, expected) } + + test("Complementation Laws") { + checkCondition('a && !'a, testRelation) + checkCondition(!'a && 'a, testRelation) + + checkCondition('a || !'a, testRelationWithData) + checkCondition(!'a || 'a, testRelationWithData) + } } From 39e303a8b6db642c26dbc26ba92e87680f50e4da Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 19 Apr 2017 18:58:14 -0700 Subject: [PATCH 310/512] [MINOR][SS] Fix a missing space in UnsupportedOperationChecker error message ## What changes were proposed in this pull request? Also went through the same file to ensure other string concatenation are correct. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17691 from zsxwing/fix-error-message. --- .../sql/catalyst/analysis/UnsupportedOperationChecker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 3f76f26dbe4e..6ab4153bac70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -267,7 +267,7 @@ object UnsupportedOperationChecker { throwError("Limits are not supported on streaming DataFrames/Datasets") case Sort(_, _, _) if !containsCompleteData(subPlan) => - throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on" + + throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + "aggregated DataFrame/Dataset in Complete output mode") case Sample(_, _, _, _, child) if child.isStreaming => From dd6d55d5de970662eccf024e5eae4e6821373d35 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 19 Apr 2017 19:53:40 -0700 Subject: [PATCH 311/512] [SPARK-20398][SQL] range() operator should include cancellation reason when killed ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-19820 adds a reason field for why tasks were killed. However, for backwards compatibility it left the old TaskKilledException constructor which defaults to "unknown reason". The range() operator should use the constructor that fills in the reason rather than dropping it on task kill. ## How was this patch tested? Existing tests, and I tested this manually. Author: Eric Liang Closes #17692 from ericl/fix-kill-reason-in-range. 
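For context, a hedged sketch of the contract this change preserves, written against the public `TaskContext` API (the loop shape and check interval are illustrative): when the scheduler kills a task, the reason it supplies should travel with the `TaskKilledException` instead of collapsing to "unknown reason", which is what `killTaskIfInterrupted()` now does inside the generated range() loop.

```scala
// Illustrative cooperative-cancellation loop: propagate the scheduler's kill reason.
import org.apache.spark.{TaskContext, TaskKilledException}

def cooperativeSum(rows: Iterator[Long]): Long = {
  val ctx = TaskContext.get()
  var sum = 0L
  var processed = 0L
  while (rows.hasNext) {
    sum += rows.next()
    processed += 1
    // Periodic check, analogous to what the generated range() code now does via
    // taskContext.killTaskIfInterrupted().
    if (processed % 10000 == 0 && ctx.isInterrupted()) {
      throw new TaskKilledException(ctx.getKillReason().getOrElse("unknown reason"))
    }
  }
  sum
}
```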
--- .../apache/spark/sql/execution/basicPhysicalOperators.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 44278e37c527..233a105f4d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -463,9 +463,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | $number = $batchEnd; | } | - | if ($taskContext.isInterrupted()) { - | throw new TaskKilledException(); - | } + | $taskContext.killTaskIfInterrupted(); | | long $nextBatchTodo; | if ($numElementsTodo > ${batchSize}L) { From bdc60569196e9ae4e9086c3e514a406a9e8b23a6 Mon Sep 17 00:00:00 2001 From: ymahajan Date: Wed, 19 Apr 2017 20:08:31 -0700 Subject: [PATCH 312/512] Fixed typos in docs ## What changes were proposed in this pull request? Typos at a couple of place in the docs. ## How was this patch tested? build including docs Please review http://spark.apache.org/contributing.html before opening a pull request. Author: ymahajan Closes #17690 from ymahajan/master. --- docs/sql-programming-guide.md | 2 +- docs/structured-streaming-programming-guide.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 28942b68fa20..490c1ce8a7cc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -571,7 +571,7 @@ be created by calling the `table` method on a `SparkSession` with the name of th For file-based data source, e.g. text, parquet, json, etc. you can specify a custom table path via the `path` option, e.g. `df.write.option("path", "/some/path").saveAsTable("t")`. When the table is dropped, the custom table path will not be removed and the table data is still there. If no custom table path is -specifed, Spark will write data to a default table path under the warehouse directory. When the table is +specified, Spark will write data to a default table path under the warehouse directory. When the table is dropped, the default table path will be removed too. Starting from Spark 2.1, persistent datasource tables have per-partition metadata stored in the Hive metastore. This brings several benefits: diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 3cf7151819e2..5b18cf2f3c2e 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -778,7 +778,7 @@ windowedCounts = words \ In this example, we are defining the watermark of the query on the value of the column "timestamp", and also defining "10 minutes" as the threshold of how late is the data allowed to be. If this query is run in Update output mode (discussed later in [Output Modes](#output-modes) section), -the engine will keep updating counts of a window in the Resule Table until the window is older +the engine will keep updating counts of a window in the Result Table until the window is older than the watermark, which lags behind the current event time in column "timestamp" by 10 minutes. Here is an illustration. 
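The watermark behaviour restated in the hunk above corresponds to the guide's Scala `windowedCounts` example; a compact sketch of that pattern, with column names and durations following the guide's running example, is:

```scala
// Watermark on the event-time column; "10 minutes" bounds how late data may arrive
// before its window stops being updated in the Result Table.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.window

def windowedCounts(words: DataFrame): DataFrame = {
  // `words` is assumed to have columns (timestamp: Timestamp, word: String).
  val spark = words.sparkSession
  import spark.implicits._
  words
    .withWatermark("timestamp", "10 minutes")
    .groupBy(window($"timestamp", "10 minutes", "5 minutes"), $"word")
    .count()
}
```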
From 46c5749768fefd976097c7d5612ec184a4cfe1b9 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 19 Apr 2017 21:19:46 -0700 Subject: [PATCH 313/512] [SPARK-20375][R] R wrappers for array and map ## What changes were proposed in this pull request? Adds wrappers for `o.a.s.sql.functions.array` and `o.a.s.sql.functions.map` ## How was this patch tested? Unit tests, `check-cran.sh` Author: zero323 Closes #17674 from zero323/SPARK-20375. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 53 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 17 ++++++++ 4 files changed, 80 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ca45c6f9b0a9..b6b559adf06e 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -213,6 +213,8 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "create_array", + "create_map", "hash", "cume_dist", "date_add", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index c311921fb33d..f854df11e576 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3652,3 +3652,56 @@ setMethod("posexplode", jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc) column(jc) }) + +#' create_array +#' +#' Creates a new array column. The input columns must all have the same data type. +#' +#' @param x Column to compute on +#' @param ... additional Column(s). +#' +#' @family normal_funcs +#' @rdname create_array +#' @name create_array +#' @aliases create_array,Column-method +#' @export +#' @examples \dontrun{create_array(df$x, df$y, df$z)} +#' @note create_array since 2.3.0 +setMethod("create_array", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "array", jcols) + column(jc) + }) + +#' create_map +#' +#' Creates a new map column. The input columns must be grouped as key-value pairs, +#' e.g. (key1, value1, key2, value2, ...). +#' The key columns must all have the same data type, and can't be null. +#' The value columns must all have the same data type. +#' +#' @param x Column to compute on +#' @param ... additional Column(s). +#' +#' @family normal_funcs +#' @rdname create_map +#' @name create_map +#' @aliases create_map,Column-method +#' @export +#' @examples \dontrun{create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))} +#' @note create_map since 2.3.0 +setMethod("create_map", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 945676c7f10b..da46823f52a1 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -942,6 +942,14 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname create_array +#' @export +setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) + +#' @rdname create_map +#' @export +setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) + #' @rdname hash #' @export setGeneric("hash", function(x, ...) 
{ standardGeneric("hash") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6a6c9a809ab1..9e87a4710699 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1461,6 +1461,23 @@ test_that("column functions", { expect_equal(length(arr$arrcol[[1]]), 2) expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + + # Test create_array() and create_map() + df <- as.DataFrame(data.frame( + x = c(1.0, 2.0), y = c(-1.0, 3.0), z = c(-2.0, 5.0) + )) + + arrs <- collect(select(df, create_array(df$x, df$y, df$z))) + expect_equal(arrs[, 1], list(list(1, -1, -2), list(2, 3, 5))) + + maps <- collect(select( + df, create_map(lit("x"), df$x, lit("y"), df$y, lit("z"), df$z))) + + expect_equal( + maps[, 1], + lapply( + list(list(x = 1, y = -1, z = -2), list(x = 2, y = 3, z = 5)), + as.environment)) }) test_that("column binary mathfunctions", { From 55bea56911a958f6d3ec3ad96fb425cc71ec03f4 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 20 Apr 2017 11:13:48 +0100 Subject: [PATCH 314/512] [SPARK-20156][SQL][FOLLOW-UP] Java String toLowerCase "Turkish locale bug" in Database and Table DDLs ### What changes were proposed in this pull request? Database and Table names conform the Hive standard ("[a-zA-z_0-9]+"), i.e. if this name only contains characters, numbers, and _. When calling `toLowerCase` on the names, we should add `Locale.ROOT` to the `toLowerCase`for avoiding inadvertent locale-sensitive variation in behavior (aka the "Turkish locale problem"). ### How was this patch tested? Added a test case Author: Xiao Li Closes #17655 from gatorsmile/locale. --- .../ResolveTableValuedFunctions.scala | 4 ++- .../sql/catalyst/catalog/SessionCatalog.scala | 4 +-- .../spark/sql/internal/SharedState.scala | 4 ++- .../sql/execution/command/DDLSuite.scala | 19 +++++++++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 28 ++++++++++++++++++- 5 files changed, 54 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 8841309939c2..de6de24350f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ @@ -103,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase()) match { + builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3fbf83f3a38a..6c6d600190b6 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -115,14 +115,14 @@ class SessionCatalog( * Format table name, taking into account case sensitivity. */ protected[this] def formatTableName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** * Format database name, taking into account case sensitivity. */ protected[this] def formatDatabaseName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 0289471bf841..d06dbaa2d0ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.Locale + import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -114,7 +116,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // System preserved database should not exists in metastore. However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. - val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase + val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase(Locale.ROOT) if (externalCatalog.databaseExists(globalTempDB)) { throw new SparkException( s"$globalTempDB is a system preserved database, please rename your existing database " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fe74ab49f91b..2f4eb1b15519 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2295,5 +2295,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } } + + test(s"basic DDL using locale tr - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withLocale("tr") { + val dbName = "DaTaBaSe_I" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + + val tabName = "tAb_I" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(col_I int) USING PARQUET") + sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1") + checkAnswer(sql(s"SELECT col_I FROM $tabName"), Row(1) :: Nil) + } + } + } + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 6a4cc95d36be..b5ad73b746a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.test import java.io.File import java.net.URI import java.nio.file.Files -import java.util.UUID +import java.util.{Locale, UUID} import scala.language.implicitConversions import scala.util.control.NonFatal @@ -228,6 +228,32 @@ private[sql] trait SQLTestUtils } } + /** + * 
Drops database `dbName` after calling `f`. + */ + protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + spark.sql(s"DROP DATABASE IF EXISTS $name") + } + } + } + + /** + * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM + * after `f` returns. + */ + protected def withLocale(language: String)(f: => Unit): Unit = { + val originalLocale = Locale.getDefault + try { + // Add Locale setting + Locale.setDefault(new Locale(language)) + f + } finally { + Locale.setDefault(originalLocale) + } + } + /** * Activates database `db` before executing `f`, then switches back to `default` database after * `f` returns. From c6f62c5b8106534007df31ca8c460064b89b450b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 20 Apr 2017 14:29:59 +0200 Subject: [PATCH 315/512] [SPARK-20405][SQL] Dataset.withNewExecutionId should be private ## What changes were proposed in this pull request? Dataset.withNewExecutionId is only used in Dataset itself and should be private. ## How was this patch tested? N/A - this is a simple visibility change. Author: Reynold Xin Closes #17699 from rxin/SPARK-20405. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 520663f62440..c6dcd93bbda6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2778,7 +2778,7 @@ class Dataset[T] private[sql]( * Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with * an execution. */ - private[sql] def withNewExecutionId[U](body: => U): U = { + private def withNewExecutionId[U](body: => U): U = { SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) } From b91873db0930c6fe885c27936e1243d5fabd03ed Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 20 Apr 2017 16:59:38 +0200 Subject: [PATCH 316/512] [SPARK-20409][SQL] fail early if aggregate function in GROUP BY ## What changes were proposed in this pull request? It's illegal to have aggregate function in GROUP BY, and we should fail at analysis phase, if this happens. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17704 from cloud-fan/minor. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 14 ++++---------- .../sql/catalyst/analysis/CheckAnalysis.scala | 7 ++++++- .../sql-tests/results/group-by-ordinal.sql.out | 4 ++-- .../apache/spark/sql/DataFrameAggregateSuite.scala | 7 +++++++ 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d9f36f7f874d..175bfb3e8085 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -966,7 +966,7 @@ class Analyzer( case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. 
- case s @ Sort(orders, global, child) + case Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => @@ -983,17 +983,11 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => - aggs(index - 1) match { - case e if ResolveAggregateFunctions.containsAggregate(e) => - ordinal.failAnalysis( - s"GROUP BY position $index is an aggregate function, and " + - "aggregate functions are not allowed in GROUP BY") - case o => o - } + case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => + aggs(index - 1) case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index da0c6b098f5c..61797bc34dc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper { } def checkValidGroupingExprs(expr: Expression): Unit = { + if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) { + failAnalysis( + "aggregate functions are not allowed in GROUP BY, but found " + expr.sql) + } + // Check if the data type of expr is orderable. 
if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( @@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper { } } - aggregateExprs.foreach(checkValidAggregateExpression) groupingExprs.foreach(checkValidGroupingExprs) + aggregateExprs.foreach(checkValidAggregateExpression) case Sort(orders, _, _) => orders.foreach { order => diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index c0930bbde69a..d03681d0ea59 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 +aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT)); -- !query 12 @@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 +aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT)); -- !query 13 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7079120bb7d..8569c2d76b69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) ) } + + test("aggregate function in GROUP BY") { + val e = intercept[AnalysisException] { + testData.groupBy(sum($"key")).count() + } + assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) + } } From c5a31d160f47ba51bb9f8a4f3141851034640fc7 Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Thu, 20 Apr 2017 18:49:39 +0200 Subject: [PATCH 317/512] [SPARK-20407][TESTS] ParquetQuerySuite 'Enabling/disabling ignoreCorruptFiles' flaky test ## What changes were proposed in this pull request? SharedSQLContext.afterEach now calls DebugFilesystem.assertNoOpenStreams inside eventually. SQLTestUtils withTempDir calls waitForTasksToFinish before deleting the directory. ## How was this patch tested? Added new test in ParquetQuerySuite based on the flaky test Author: Bogdan Raducanu Closes #17701 from bogdanrdc/SPARK-20407. 
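A hedged sketch of the retry idiom this fix applies (the helper name and stream-count accessor are illustrative): an assertion that can be transiently false, because streams are closed from other threads shortly after tasks finish, is retried for up to ten seconds instead of being evaluated exactly once in `afterEach`.

```scala
// Retry a transiently-false assertion instead of failing on the first evaluation.
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

def assertEventuallyNoOpenStreams(openStreamCount: () => Int): Unit = {
  eventually(timeout(10.seconds)) {
    assert(openStreamCount() == 0, "expected all DebugFilesystem streams to be closed")
  }
}
```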
--- .../parquet/ParquetQuerySuite.scala | 35 ++++++++++++++++++- .../apache/spark/sql/test/SQLTestUtils.scala | 19 ++++++++-- .../spark/sql/test/SharedSQLContext.scala | 13 ++++--- 3 files changed, 60 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index c36609586c80..2efff3f57d7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -23,7 +23,7 @@ import java.sql.Timestamp import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.SparkException +import org.apache.spark.{DebugFilesystem, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow @@ -316,6 +316,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + /** + * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop + * to increase the chance of failure + */ + ignore("SPARK-20407 ParquetQuerySuite 'Enabling/disabling ignoreCorruptFiles' flaky test") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer( + df, + Seq(Row(0), Row(1))) + } + } + + for (i <- 1 to 100) { + DebugFilesystem.clearOpenStreams() + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreCorruptFiles() + } + assert(exception.getMessage().contains("is not a Parquet file")) + } + DebugFilesystem.assertNoOpenStreams() + } + } + test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index b5ad73b746a8..44c0fc70d066 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -22,11 +22,13 @@ import java.net.URI import java.nio.file.Files import java.util.{Locale, UUID} +import scala.concurrent.duration._ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ @@ -49,7 +51,7 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. 
*/ private[sql] trait SQLTestUtils - extends SparkFunSuite + extends SparkFunSuite with Eventually with BeforeAndAfterAll with SQLTestData { self => @@ -138,6 +140,15 @@ private[sql] trait SQLTestUtils } } + /** + * Waits for all tasks on all executors to be finished. + */ + protected def waitForTasksToFinish(): Unit = { + eventually(timeout(10.seconds)) { + assert(spark.sparkContext.statusTracker + .getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -146,7 +157,11 @@ private[sql] trait SQLTestUtils */ protected def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) + try f(dir) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + Utils.deleteRecursively(dir) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e122b39f6fc4..3d76e05f616d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.test +import scala.concurrent.duration._ + import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { protected val sparkConf = new SparkConf() @@ -84,6 +85,10 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected override def afterEach(): Unit = { super.afterEach() - DebugFilesystem.assertNoOpenStreams() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } } } From b2ebadfd55283348b8a8b37e28075fca0798228a Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 20 Apr 2017 09:55:10 -0700 Subject: [PATCH 318/512] [SPARK-20358][CORE] Executors failing stage on interrupted exception thrown by cancelled tasks ## What changes were proposed in this pull request? This was a regression introduced by my earlier PR here: https://github.com/apache/spark/pull/17531 It turns out NonFatal() does not in fact catch InterruptedException. ## How was this patch tested? Extended cancellation unit test coverage. The first test fails before this patch. cc JoshRosen mridulm Author: Eric Liang Closes #17659 from ericl/spark-20358. 
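The root cause is easy to verify in isolation: `scala.util.control.NonFatal` deliberately excludes `InterruptedException`, so a `case NonFatal(_)` guard on its own never sees the interrupt raised when a task is cancelled. A minimal check:

```scala
// NonFatal matches ordinary exceptions but not InterruptedException (nor other fatal errors).
import scala.util.control.NonFatal

def classify(t: Throwable): String = t match {
  case NonFatal(_) => "non-fatal"
  case _           => "not matched by NonFatal"
}

println(classify(new RuntimeException("boom")))     // non-fatal
println(classify(new InterruptedException("kill"))) // not matched by NonFatal
```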
--- .../org/apache/spark/executor/Executor.scala | 3 ++- .../org/apache/spark/SparkContextSuite.scala | 26 ++++++++++++------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 83469c5ff060..18f04391d64c 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -432,7 +432,8 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case NonFatal(_) if task != null && task.reasonIfKilled.isDefined => + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 735f4454e299..7e26139a2bea 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -540,10 +540,24 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } - // Launches one task that will run forever. Once the SparkListener detects the task has + testCancellingTasks("that raise interrupted exception on cancel") { + Thread.sleep(9999999) + } + + // SPARK-20217 should not fail stage if task throws non-interrupted exception + testCancellingTasks("that raise runtime exception on cancel") { + try { + Thread.sleep(9999999) + } catch { + case t: Throwable => + throw new RuntimeException("killed") + } + } + + // Launches one task that will block forever. Once the SparkListener detects the task has // started, kill and re-schedule it. The second run of the task will complete immediately. // If this test times out, then the first version of the task wasn't killed successfully. - test("Killing tasks") { + def testCancellingTasks(desc: String)(blockFn: => Unit): Unit = test(s"Killing tasks $desc") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) SparkContextSuite.isTaskStarted = false @@ -572,13 +586,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu // first attempt will hang if (!SparkContextSuite.isTaskStarted) { SparkContextSuite.isTaskStarted = true - try { - Thread.sleep(9999999) - } catch { - case t: Throwable => - // SPARK-20217 should not fail stage if task throws non-interrupted exception - throw new RuntimeException("killed") - } + blockFn } // second attempt succeeds immediately } From d95e4d9d6a9705c534549add6d4a73d554e47274 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 20 Apr 2017 22:35:48 +0200 Subject: [PATCH 319/512] [SPARK-20334][SQL] Return a better error message when correlated predicates contain aggregate expression that has mixture of outer and local references. ## What changes were proposed in this pull request? 
Address a follow up in [comment](https://github.com/apache/spark/pull/16954#discussion_r105718880) Currently subqueries with correlated predicates containing aggregate expression having mixture of outer references and local references generate a codegen error like following : ```SQL SELECT t1a FROM t1 GROUP BY 1 HAVING EXISTS (SELECT 1 FROM t2 WHERE t2a < min(t1a + t2a)); ``` Exception snippet. ``` Cannot evaluate expression: min((input[0, int, false] + input[4, int, false])) at org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:226) at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:87) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:106) at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:103) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.sql.catalyst.expressions.Expression.genCode(Expression.scala:103) ``` After this PR, a better error message is issued. ``` org.apache.spark.sql.AnalysisException Error in query: Found an aggregate expression in a correlated predicate that has both outer and local references, which is not supported yet. Aggregate expression: min((t1.`t1a` + t2.`t2a`)), Outer references: t1.`t1a`, Local references: t2.`t2a`.; ``` ## How was this patch tested? Added tests in SQLQueryTestSuite. Author: Dilip Biswal Closes #17636 from dilipbiswal/subquery_followup1. --- .../sql/catalyst/analysis/Analyzer.scala | 49 +++++++--- .../negative-cases/invalid-correlation.sql | 74 +++++++++----- .../invalid-correlation.sql.out | 96 ++++++++++++++----- .../org/apache/spark/sql/SubquerySuite.scala | 23 ++++- 4 files changed, 181 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 175bfb3e8085..eafeb4ac1ae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1204,6 +1204,28 @@ class Analyzer( private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { val outerReferences = ArrayBuffer.empty[Expression] + // Validate that correlated aggregate expression do not contain a mixture + // of outer and local references. + def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { + expr.foreach { + case a: AggregateExpression if containsOuter(a) => + val outer = a.collect { case OuterReference(e) => e.toAttribute } + val local = a.references -- outer + if (local.nonEmpty) { + val msg = + s""" + |Found an aggregate expression in a correlated predicate that has both + |outer and local references, which is not supported yet. + |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, + |Outer references: ${outer.map(_.sql).mkString(", ")}, + |Local references: ${local.map(_.sql).mkString(", ")}. 
+ """.stripMargin.replace("\n", " ").trim() + failAnalysis(msg) + } + case _ => + } + } + // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { if (hasOuterReferences(p)) { @@ -1211,9 +1233,12 @@ class Analyzer( } } - // Make sure a plan's expressions do not contain outer references - def failOnOuterReference(p: LogicalPlan): Unit = { - if (p.expressions.exists(containsOuter)) { + // Make sure a plan's expressions do not contain : + // 1. Aggregate expressions that have mixture of outer and local references. + // 2. Expressions containing outer references on plan nodes other than Filter. + def failOnInvalidOuterReference(p: LogicalPlan): Unit = { + p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) + if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { failAnalysis( "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + s"clauses:\n$p") @@ -1283,9 +1308,9 @@ class Analyzer( // These operators can be anywhere in a correlated subquery. // so long as they do not host outer references in the operators. case s: Sort => - failOnOuterReference(s) + failOnInvalidOuterReference(s) case r: RepartitionByExpression => - failOnOuterReference(r) + failOnInvalidOuterReference(r) // Category 3: // Filter is one of the two operators allowed to host correlated expressions. @@ -1299,6 +1324,8 @@ class Analyzer( case _: EqualTo | _: EqualNullSafe => false case _ => true } + + failOnInvalidOuterReference(f) // The aggregate expressions are treated in a special way by getOuterReferences. If the // aggregate expression contains only outer reference attributes then the entire aggregate // expression is isolated as an OuterReference. @@ -1308,7 +1335,7 @@ class Analyzer( // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. case p: Project => - failOnOuterReference(p) + failOnInvalidOuterReference(p) // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains @@ -1316,7 +1343,7 @@ class Analyzer( // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. case a: Aggregate => - failOnOuterReference(a) + failOnInvalidOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) // Join can host correlated expressions. @@ -1324,7 +1351,7 @@ class Analyzer( joinType match { // Inner join, like Filter, can be anywhere. case _: InnerLike => - failOnOuterReference(j) + failOnInvalidOuterReference(j) // Left outer join's right operand cannot be on a correlation path. // LeftAnti and ExistenceJoin are special cases of LeftOuter. @@ -1335,12 +1362,12 @@ class Analyzer( // Any correlated references in the subplan // of the right operand cannot be pulled up. case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - failOnOuterReference(j) + failOnInvalidOuterReference(j) failOnOuterReferenceInSubTree(right) // Likewise, Right outer join's left operand cannot be on a correlation path. case RightOuter => - failOnOuterReference(j) + failOnInvalidOuterReference(j) failOnOuterReferenceInSubTree(left) // Any other join types not explicitly listed above, @@ -1356,7 +1383,7 @@ class Analyzer( // Note: // Generator with join=false is treated as Category 4. 
case g: Generate if g.join => - failOnOuterReference(g) + failOnInvalidOuterReference(g) // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql index cf93c5a83597..e22cade93679 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql @@ -1,42 +1,72 @@ -- The test file contains negative test cases -- of invalid queries where error messages are expected. -create temporary view t1 as select * from values +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 2, 3) -as t1(t1a, t1b, t1c); +AS t1(t1a, t1b, t1c); -create temporary view t2 as select * from values +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 0, 1) -as t2(t2a, t2b, t2c); +AS t2(t2a, t2b, t2c); -create temporary view t3 as select * from values +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) -as t3(t3a, t3b, t3c); +AS t3(t3a, t3b, t3c); -- TC 01.01 -- The column t2b in the SELECT of the subquery is invalid -- because it is neither an aggregate function nor a GROUP BY column. -select t1a, t2b -from t1, t2 -where t1b = t2c -and t2b = (select max(avg) - from (select t2b, avg(t2b) avg - from t2 - where t2a = t1.t1b +SELECT t1a, t2b +FROM t1, t2 +WHERE t1b = t2c +AND t2b = (SELECT max(avg) + FROM (SELECT t2b, avg(t2b) avg + FROM t2 + WHERE t2a = t1.t1b ) ) ; -- TC 01.02 -- Invalid due to the column t2b not part of the output from table t2. -select * -from t1 -where t1a in (select min(t2a) - from t2 - group by t2c - having t2c in (select max(t3c) - from t3 - group by t3b - having t3b > t2b )) +SELECT * +FROM t1 +WHERE t1a IN (SELECT min(t2a) + FROM t2 + GROUP BY t2c + HAVING t2c IN (SELECT max(t3c) + FROM t3 + GROUP BY t3b + HAVING t3b > t2b )) ; +-- TC 01.03 +-- Invalid due to mixure of outer and local references under an AggegatedExpression +-- in a correlated predicate +SELECT t1a +FROM t1 +GROUP BY 1 +HAVING EXISTS (SELECT 1 + FROM t2 + WHERE t2a < min(t1a + t2a)); + +-- TC 01.04 +-- Invalid due to mixure of outer and local references under an AggegatedExpression +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT 1 + FROM t3 + GROUP BY 1 + HAVING min(t2a + t3a) > 1)); + +-- TC 01.05 +-- Invalid due to outer reference appearing in projection list +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT min(t2a) + FROM t3)); + diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index f7bbb35aad6c..e4b1a2dbc675 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -1,11 +1,11 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 5 +-- Number of queries: 8 -- !query 0 -create temporary view t1 as select * from values +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 2, 3) -as t1(t1a, t1b, t1c) +AS t1(t1a, t1b, t1c) -- !query 0 schema struct<> -- !query 0 output @@ -13,9 +13,9 @@ struct<> -- !query 1 -create 
temporary view t2 as select * from values +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 0, 1) -as t2(t2a, t2b, t2c) +AS t2(t2a, t2b, t2c) -- !query 1 schema struct<> -- !query 1 output @@ -23,9 +23,9 @@ struct<> -- !query 2 -create temporary view t3 as select * from values +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) -as t3(t3a, t3b, t3c) +AS t3(t3a, t3b, t3c) -- !query 2 schema struct<> -- !query 2 output @@ -33,13 +33,13 @@ struct<> -- !query 3 -select t1a, t2b -from t1, t2 -where t1b = t2c -and t2b = (select max(avg) - from (select t2b, avg(t2b) avg - from t2 - where t2a = t1.t1b +SELECT t1a, t2b +FROM t1, t2 +WHERE t1b = t2c +AND t2b = (SELECT max(avg) + FROM (SELECT t2b, avg(t2b) avg + FROM t2 + WHERE t2a = t1.t1b ) ) -- !query 3 schema @@ -50,17 +50,67 @@ grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate funct -- !query 4 -select * -from t1 -where t1a in (select min(t2a) - from t2 - group by t2c - having t2c in (select max(t3c) - from t3 - group by t3b - having t3b > t2b )) +SELECT * +FROM t1 +WHERE t1a IN (SELECT min(t2a) + FROM t2 + GROUP BY t2c + HAVING t2c IN (SELECT max(t3c) + FROM t3 + GROUP BY t3b + HAVING t3b > t2b )) -- !query 4 schema struct<> -- !query 4 output org.apache.spark.sql.AnalysisException resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); + + +-- !query 5 +SELECT t1a +FROM t1 +GROUP BY 1 +HAVING EXISTS (SELECT 1 + FROM t2 + WHERE t2a < min(t1a + t2a)) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Found an aggregate expression in a correlated predicate that has both outer and local references, which is not supported yet. Aggregate expression: min((t1.`t1a` + t2.`t2a`)), Outer references: t1.`t1a`, Local references: t2.`t2a`.; + + +-- !query 6 +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT 1 + FROM t3 + GROUP BY 1 + HAVING min(t2a + t3a) > 1)) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +Found an aggregate expression in a correlated predicate that has both outer and local references, which is not supported yet. 
Aggregate expression: min((t2.`t2a` + t3.`t3a`)), Outer references: t2.`t2a`, Local references: t3.`t3a`.; + + +-- !query 7 +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT min(t2a) + FROM t3)) +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses: +Aggregate [min(outer(t2a#x)) AS min(outer())#x] ++- SubqueryAlias t3 + +- Project [t3a#x, t3b#x, t3c#x] + +- SubqueryAlias t3 + +- LocalRelation [t3a#x, t3b#x, t3c#x] +; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 0f0199cbe277..131abf7c1e5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -822,12 +822,25 @@ class SubquerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql( """ - | select c2 - | from t1 - | where exists (select * - | from t2 lateral view explode(arr_c2) q as c2 - where t1.c1 = t2.c1)""".stripMargin), + | SELECT c2 + | FROM t1 + | WHERE EXISTS (SELECT * + | FROM t2 LATERAL VIEW explode(arr_c2) q AS c2 + WHERE t1.c1 = t2.c1)""".stripMargin), Row(1) :: Row(0) :: Nil) + + val msg1 = intercept[AnalysisException] { + sql( + """ + | SELECT c1 + | FROM t2 + | WHERE EXISTS (SELECT * + | FROM t1 LATERAL VIEW explode(t2.arr_c2) q AS c2 + | WHERE t1.c1 = t2.c1) + """.stripMargin) + } + assert(msg1.getMessage.contains( + "Expressions referencing the outer query are not supported outside of WHERE/HAVING")) } } From 033206355339677812a250b2b64818a261871fd2 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 20 Apr 2017 22:37:04 +0200 Subject: [PATCH 320/512] [SPARK-20410][SQL] Make sparkConf a def in SharedSQLContext ## What changes were proposed in this pull request? It is kind of annoying that `SharedSQLContext.sparkConf` is a val when overriding test cases, because you cannot call `super` on it. This PR makes it a function. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #17705 from hvanhovell/SPARK-20410. 
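For context on why the shared member becomes a method (the trait and class names below are illustrative, not the actual test hierarchy): Scala only allows `super.x` when `x` is a method, so a `def` lets each suite layer its own settings on top of the inherited configuration.

```scala
// Overriding a def and delegating to super composes cleanly; with a val in the parent,
// the call to super.sparkConf would be rejected by the compiler.
import org.apache.spark.SparkConf

trait BaseSuiteLike {
  protected def sparkConf: SparkConf = new SparkConf().set("spark.app.name", "base")
}

class DerivedSuiteLike extends BaseSuiteLike {
  override protected def sparkConf: SparkConf =
    super.sparkConf.set("spark.sql.codegen.fallback", "false")
}
```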
--- .../spark/sql/AggregateHashMapSuite.scala | 35 ++++++++----------- .../DatasetSerializerRegistratorSuite.scala | 12 +++---- .../DataSourceScanExecRedactionSuite.scala | 11 ++---- .../datasources/FileSourceStrategySuite.scala | 2 +- .../CompactibleFileStreamLogSuite.scala | 4 +-- .../streaming/HDFSMetadataLogSuite.scala | 4 +-- .../spark/sql/test/SharedSQLContext.scala | 7 ++-- 7 files changed, 32 insertions(+), 43 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 3e85d9552312..7e61a6802515 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { +import org.apache.spark.SparkConf - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") - super.beforeAll() - } +class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -38,12 +37,9 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo } class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { - - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -55,15 +51,14 @@ class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeA } } -class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with -BeforeAndAfter { +class TwoLevelAggregateHashMapWithVectorizedMapSuite + extends DataFrameAggregateSuite + with BeforeAndAfter { - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala index 92c5656f65bb..68f7de047b39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql import com.esotericsoftware.kryo.{Kryo, Serializer} import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.test.TestSparkSession /** * Test suite to test Kryo custom registrators. @@ -30,12 +30,10 @@ import org.apache.spark.sql.test.TestSparkSession class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext { import testImplicits._ - /** - * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]]. - */ - protected override def beforeAll(): Unit = { - sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) - super.beforeAll() + + override protected def sparkConf: SparkConf = { + // Make sure we use the KryoRegistrator + super.sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) } test("Kryo registrator") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 05a2b2c862c7..f7f1ccea281c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -18,22 +18,17 @@ package org.apache.spark.sql.execution import org.apache.hadoop.fs.Path +import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils /** * Suite that tests the redaction of DataSourceScanExec */ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { - import Utils._ - - override def beforeAll(): Unit = { - sparkConf.set("spark.redaction.string.regex", - "file:/[\\w_]+") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.redaction.string.regex", "file:/[\\w_]+") test("treeString is redacted") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index f36162858bf7..8703fe96e587 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.Utils class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ - protected override val sparkConf = new SparkConf().set("spark.default.parallelism", "1") + protected override def sparkConf = super.sparkConf.set("spark.default.parallelism", "1") test("unpartitioned table, single partition") { val table = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 20ac06f048c6..3d480b148db5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") import CompactibleFileStreamLog._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 662c4466b21b..7689bc03a4cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -38,8 +38,8 @@ import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") private implicit def toOption[A](a: A): Option[A] = Option(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 3d76e05f616d..81c69a338abc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -30,7 +30,9 @@ import org.apache.spark.sql.{SparkSession, SQLContext} */ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - protected val sparkConf = new SparkConf() + protected def sparkConf = { + new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + } /** * The [[TestSparkSession]] to use for all tests in this suite. @@ -51,8 +53,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventua protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { - new TestSparkSession( - sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + new TestSparkSession(sparkConf) } /** From 592f5c89349f3c5b6ec0531c6514b8f7d95ad8da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 20 Apr 2017 16:02:09 -0700 Subject: [PATCH 321/512] [SPARK-20172][CORE] Add file permission check when listing files in FsHistoryProvider ## What changes were proposed in this pull request? In the current Spark HistoryServer we expect to get an `AccessControlException` while listing all the files, but unfortunately this never worked because we do not actually check the access permission, and no other call throws such an exception. Worse, the check is deferred until the files are read, which is unnecessary and quite verbose, since the exception is printed every 10 seconds while checking the files. So with this fix we check the read permission while listing the files, which avoids unnecessary file reads later on and suppresses the verbose log. ## How was this patch tested? Added a unit test to verify.
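The added check boils down to a POSIX-style owner/group/other test. A rough, self-contained sketch of that logic, using plain Scala types instead of Hadoop's `FileStatus`/`FsAction`/`UserGroupInformation`:

```scala
// Simplified stand-in types; the real check works on Hadoop's FileStatus,
// FsPermission and UserGroupInformation rather than these toy classes.
case class FileMeta(
    owner: String,
    group: String,
    userPerms: Set[Char],
    groupPerms: Set[Char],
    otherPerms: Set[Char])

// Mirrors the owner -> group -> other fall-through: a matching owner (or group)
// that lacks the requested mode is denied without consulting the later classes.
def canAccess(meta: FileMeta, user: String, groups: Set[String], mode: Char): Boolean = {
  if (user == meta.owner) meta.userPerms.contains(mode)
  else if (groups.contains(meta.group)) meta.groupPerms.contains(mode)
  else meta.otherPerms.contains(mode)
}

// canAccess(FileMeta("alice", "staff", Set('r', 'w'), Set('r'), Set.empty), "bob", Set("staff"), 'r')  // true
```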
Author: jerryshao Closes #17495 from jerryshao/SPARK-20172. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 23 +++++ .../deploy/history/FsHistoryProvider.scala | 28 +++--- .../spark/deploy/SparkHadoopUtilSuite.scala | 97 +++++++++++++++++++ .../history/FsHistoryProviderSuite.scala | 16 ++- 4 files changed, 145 insertions(+), 19 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index bae7a3f307f5..9cc321af4bde 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -353,6 +354,28 @@ class SparkHadoopUtil extends Logging { } buffer.toString } + + private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { + val perm = status.getPermission + val ugi = UserGroupInformation.getCurrentUser + + if (ugi.getShortUserName == status.getOwner) { + if (perm.getUserAction.implies(mode)) { + return true + } + } else if (ugi.getGroupNames.contains(status.getGroup)) { + if (perm.getGroupAction.implies(mode)) { + return true + } + } else if (perm.getOtherAction.implies(mode)) { + return true + } + + logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + + s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + + s"${if (status.isDirectory) "d" else "-"}$perm") + false + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9012736bc274..f4235df24512 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -27,7 +27,8 @@ import scala.xml.Node import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -318,21 +319,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => - try { - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) - !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally - // reading a garbage file is safe, but we would log an error which can be scary to - // the end-user. 
- !entry.getPath().getName().startsWith(".") && - prevFileSize < entry.getLen() - } catch { - case e: AccessControlException => - // Do not use "logInfo" since these messages can get pretty noisy if printed on - // every poll. - logDebug(s"No permission to read $entry, ignoring.") - false - } + val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + !entry.isDirectory() && + // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // reading a garbage file is safe, but we would log an error which can be scary to + // the end-user. + !entry.getPath().getName().startsWith(".") && + prevFileSize < entry.getLen() && + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => @@ -445,7 +439,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the log files in the list and merge the list of old applications with new ones */ - private def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { val newAttempts = try { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala new file mode 100644 index 000000000000..ab24a76e20a3 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.security.PrivilegedExceptionAction + +import scala.util.Random + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.permission.{FsAction, FsPermission} +import org.apache.hadoop.security.UserGroupInformation +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { + test("check file permission") { + import FsAction._ + val testUser = s"user-${Random.nextInt(100)}" + val testGroups = Array(s"group-${Random.nextInt(100)}") + val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) + + testUgi.doAs(new PrivilegedExceptionAction[Void] { + override def run(): Void = { + val sparkHadoopUtil = new SparkHadoopUtil + + // If file is owned by user and user has access permission + var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user but user has no access permission + status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + val otherUser = s"test-${Random.nextInt(100)}" + val otherGroup = s"test-${Random.nextInt(100)}" + + // If file is owned by user's group and user's group has access permission + status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user's group but user's group has no access permission + status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + // If file is owned by other user and this user has access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by other user but this user has no access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + null + } + }) + } + + private def fileStatus( + owner: String, + group: String, + userAction: FsAction, + groupAction: FsAction, + otherAction: FsAction): FileStatus = { + new FileStatus(0L, + false, + 0, + 0L, + 0L, + 0L, + new FsPermission(userAction, groupAction, otherAction), + owner, + group, + null) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index ec580a44b8e7..456158d41b93 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -27,6 +27,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.fs.FileStatus import 
org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -130,9 +131,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("SPARK-3697: ignore directories that cannot be read.") { + test("SPARK-3697: ignore files that cannot be read.") { // setReadable(...) does not work on Windows. Please refer JDK-6728842. assume(!Utils.isWindows) + + class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { + var mergeApplicationListingCall = 0 + override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + super.mergeApplicationListing(fileStatus) + mergeApplicationListingCall += 1 + } + } + val provider = new TestFsHistoryProvider + val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), @@ -145,10 +156,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) logFile2.setReadable(false, false) - val provider = new FsHistoryProvider(createTestConf()) updateAndCheck(provider) { list => list.size should be (1) } + + provider.mergeApplicationListingCall should be (1) } test("history file is renamed from inprogress to completed") { From 0368eb9d86634c83b3140ce3190cb9e0d0b7fd86 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 21 Apr 2017 09:49:42 +0800 Subject: [PATCH 322/512] [SPARK-20367] Properly unescape column names of partitioning columns parsed from paths. ## What changes were proposed in this pull request? When infering partitioning schema from paths, the column in parsePartitionColumn should be unescaped with unescapePathName, just like it is being done in e.g. parsePathFragmentAsSeq. ## How was this patch tested? Added a test to FileIndexSuite. Author: Juliusz Sompolski Closes #17703 from juliuszsompolski/SPARK-20367. 
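The idea behind the fix — both the name and the value of a `name=value` partition directory fragment need unescaping — can be illustrated with a small sketch. `URLDecoder` is only a stand-in for Spark's `unescapePathName`, and the helper below is hypothetical, not the code added by this patch:

```scala
import java.net.URLDecoder

// Hypothetical helper for illustration only.
def parsePartitionFragment(fragment: String): Option[(String, String)] = {
  val eq = fragment.indexOf('=')
  if (eq == -1) {
    None
  } else {
    // The column-name part is what this fix now unescapes, in addition to the value.
    val columnName = URLDecoder.decode(fragment.take(eq), "UTF-8")
    val value = URLDecoder.decode(fragment.drop(eq + 1), "UTF-8")
    Some(columnName -> value)
  }
}

// parsePartitionFragment("Column%2F%23%25=10") == Some(("Column/#%", "10"))
```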
--- .../execution/datasources/PartitioningUtils.scala | 2 +- .../sql/execution/datasources/FileIndexSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index c3583209efc5..2d70172487e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -243,7 +243,7 @@ object PartitioningUtils { if (equalSignIndex == -1) { None } else { - val columnName = columnSpec.take(equalSignIndex) + val columnName = unescapePathName(columnSpec.take(equalSignIndex)) assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") val rawColumnValue = columnSpec.drop(equalSignIndex + 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index a9511cbd9e4c..b4616826e40b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} @@ -236,6 +237,17 @@ class FileIndexSuite extends SharedSQLContext { val fileStatusCache = FileStatusCache.getOrCreate(spark) fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray) } + + test("SPARK-20367 - properly unescape column names in inferPartitioning") { + withTempPath { path => + val colToUnescape = "Column/#%'?" + spark + .range(1) + .select(col("id").as(colToUnescape), col("id")) + .write.partitionBy(colToUnescape).parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).schema.exists(_.name == colToUnescape)) + } + } } class FakeParentPathFileSystem extends RawLocalFileSystem { From 760c8d088df1d35d7b8942177d47bc1677daf143 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 21 Apr 2017 10:06:12 +0800 Subject: [PATCH 323/512] [SPARK-20329][SQL] Make timezone aware expression without timezone unresolved ## What changes were proposed in this pull request? A cast expression with a resolved time zone is not equal to a cast expression without a resolved time zone. The `ResolveAggregateFunction` assumed that these expression were the same, and would fail to resolve `HAVING` clauses which contain a `Cast` expression. This is in essence caused by the fact that a `TimeZoneAwareExpression` can be resolved without a set time zone. This PR fixes this, and makes a `TimeZoneAwareExpression` unresolved as long as it has no TimeZone set. ## How was this patch tested? Added a regression test to the `SQLQueryTestSuite.having` file. Author: Herman van Hovell Closes #17641 from hvanhovell/SPARK-20329. 
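The essence of the change can be shown with a toy model; the traits below are simplified stand-ins for Catalyst's `Expression`/`TimeZoneAwareExpression`, not the real classes:

```scala
// Toy model: an expression stays unresolved until a time zone has been assigned.
trait Expr {
  def childrenResolved: Boolean = true
  def resolved: Boolean = childrenResolved
}

trait TimeZoneAware extends Expr {
  def timeZoneId: Option[String]
  // The key change: without a time zone id the expression is not yet resolved,
  // so the analyzer keeps applying rules until one is set.
  override def resolved: Boolean = childrenResolved && timeZoneId.isDefined
}

case class ToyCast(timeZoneId: Option[String]) extends TimeZoneAware

// ToyCast(None).resolved == false; ToyCast(Some("GMT")).resolved == true
```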
--- .../sql/catalyst/analysis/Analyzer.scala | 20 +----- .../analysis/ResolveInlineTables.scala | 10 +-- .../catalyst/analysis/timeZoneAnalysis.scala | 61 +++++++++++++++++++ .../spark/sql/catalyst/analysis/view.scala | 4 +- .../expressions/datetimeExpressions.scala | 4 +- .../analysis/ResolveInlineTablesSuite.scala | 10 +-- .../catalyst/analysis/TypeCoercionSuite.scala | 35 ++++++----- .../sql/catalyst/expressions/CastSuite.scala | 4 +- .../expressions/DateExpressionsSuite.scala | 6 +- .../expressions/ExpressionEvalHelper.scala | 7 ++- .../spark/sql/execution/SparkPlanner.scala | 2 +- .../datasources/DataSourceStrategy.scala | 20 +++--- .../sql/execution/datasources/rules.scala | 6 +- .../internal/BaseSessionStateBuilder.scala | 2 +- .../resources/sql-tests/inputs/having.sql | 3 + .../sql-tests/results/having.sql.out | 11 +++- .../spark/sql/sources/BucketedReadSuite.scala | 3 +- .../sql/sources/DataSourceAnalysisSuite.scala | 16 +++-- .../sql/hive/HiveSessionStateBuilder.scala | 2 +- 19 files changed, 148 insertions(+), 78 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index eafeb4ac1ae5..dcadbbc90f43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -150,6 +150,7 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveTimeZone(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -161,8 +162,6 @@ class Analyzer( HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), - Batch("ResolveTimeZone", Once, - ResolveTimeZone), Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, @@ -2368,23 +2367,6 @@ class Analyzer( } } } - - /** - * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local - * time zone. - */ - object ResolveTimeZone extends Rule[LogicalPlan] { - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - // Casts could be added in the subquery plan through the rule TypeCoercion while coercing - // the types between the value expression and list query expression of IN expression. - // We need to subject the subquery plan through ResolveTimeZone again to setup timezone - // information for time zone aware expressions. 
- case e: ListQuery => e.withNewPlan(apply(e.plan)) - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index a991dd96e282..f2df3e132629 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ -case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) @@ -99,12 +98,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { val castedExpr = if (e.dataType.sameType(targetType)) { e } else { - Cast(e, targetType) + cast(e, targetType) } - castedExpr.transform { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - }.eval() + castedExpr.eval() } catch { case NonFatal(ex) => table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala new file mode 100644 index 000000000000..a27aa845bf0a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local + * time zone. 
+ */ +case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { + private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. + case e: ListQuery => e.withNewPlan(apply(e.plan)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan.resolveExpressions(transformTimeZoneExprs) + + def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) +} + +/** + * Mix-in trait for constructing valid [[Cast]] expressions. + */ +trait CastSupport { + /** + * Configuration used to create a valid cast expression. + */ + def conf: SQLConf + + /** + * Create a Cast expression with the session local time zone. + */ + def cast(child: Expression, dataType: DataType): Cast = { + Cast(child, dataType, Option(conf.sessionLocalTimeZone)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 3bd54c257d98..ea46dd728240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf * This should be only done after the batch of Resolution, because the view attributes are not * completely resolved during the batch of Resolution. 
*/ -case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { +case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver @@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " + s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n") } else { - Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, + Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata)) } case (_, originAttr) => originAttr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index f8fe774823e5..bb8fd5032d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * Common base class for time zone aware expressions. */ trait TimeZoneAwareExpression extends Expression { + /** The expression is only resolved when the time zone has been set. */ + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined /** the timezone ID to be used to evaluate value. 
*/ def timeZoneId: Option[String] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index f45a82686984..d0fe81505225 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{LongType, NullType, TimestampType} /** @@ -91,12 +92,13 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { test("convert TimeZoneAwareExpression") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) - val converted = ResolveInlineTables(conf).convert(table) + val withTimeZone = ResolveTimeZone(conf).apply(table) + val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] - assert(converted.output.map(_.dataType) == Seq(TimestampType)) - assert(converted.data.size == 1) - assert(converted.data(0).getLong(0) == correct) + assert(output.map(_.dataType) == Seq(TimestampType)) + assert(data.size == 1) + assert(data.head.getLong(0) == correct) } test("nullability inference in convert") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 011d09ff6064..2624f5586fd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -787,6 +788,12 @@ class TypeCoercionSuite extends PlanTest { } } + private val timeZoneResolver = ResolveTimeZone(new SQLConf) + + private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) + } + test("WidenSetOperationTypes for except and intersect") { val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), @@ -799,11 +806,10 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] - val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = widenSetOperationTypes(Intersect(firstTable, 
secondTable)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) @@ -838,10 +844,9 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("p", ByteType)(), AttributeReference("q", DoubleType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val unionRelation = wt( + val unionRelation = widenSetOperationTypes( Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] assert(unionRelation.children.length == 4) checkOutput(unionRelation.children.head, expectedTypes) @@ -862,17 +867,15 @@ class TypeCoercionSuite extends PlanTest { } } - val dp = TypeCoercion.WidenSetOperationTypes - val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) val right1 = LocalRelation( AttributeReference("r", DecimalType(5, 5))()) val expectedType1 = Seq(DecimalType(10, 8)) - val r1 = dp(Union(left1, right1)).asInstanceOf[Union] - val r2 = dp(Except(left1, right1)).asInstanceOf[Except] - val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedType1) checkOutput(r1.children.last, expectedType1) @@ -891,17 +894,17 @@ class TypeCoercionSuite extends PlanTest { val plan2 = LocalRelation( AttributeReference("r", rType)()) - val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] - val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] - val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect] checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) - val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] - val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] - val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] + val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] + val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except] + val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect] checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a7ffa884d228..22f3f3514fa4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String */ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { - private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = { + private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = { v match { case lit: Expression => Cast(lit, targetType, timeZoneId) case _ => Cast(Literal(v), targetType, timeZoneId) @@ -47,7 +47,7 @@ class 
CastSuite extends SparkFunSuite with ExpressionEvalHelper { } private def checkNullCast(from: DataType, to: DataType): Unit = { - checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null) + checkEvaluation(cast(Literal.create(null, from), to), null) } test("null cast") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 9978f35a0381..ca89bf7db0b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -160,7 +160,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Seconds") { assert(Second(Literal.create(null, DateType), gmtId).resolved === false) - assert(Second(Cast(Literal(d), TimestampType), None).resolved === true) + assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true) checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15) checkEvaluation(Second(Literal(ts), gmtId), 15) @@ -220,7 +220,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Hour") { assert(Hour(Literal.create(null, DateType), gmtId).resolved === false) - assert(Hour(Literal(ts), None).resolved === true) + assert(Hour(Literal(ts), gmtId).resolved === true) checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13) checkEvaluation(Hour(Literal(ts), gmtId), 13) @@ -246,7 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Minute") { assert(Minute(Literal.create(null, DateType), gmtId).resolved === false) - assert(Minute(Literal(ts), None).resolved === true) + assert(Minute(Literal(ts), gmtId).resolved === true) checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation( Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 1ba6dd1c5e8c..b6399edb68dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends 
GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val serializer = new JavaSerializer(new SparkConf()).newInstance - val expr: Expression = serializer.deserialize(serializer.serialize(expression)) + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 6566502bd8a8..4e718d609c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,7 +36,7 @@ class SparkPlanner( experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( FileSourceStrategy :: - DataSourceStrategy :: + DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: JoinSelection :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d83d512e702..d307122b5c70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. */ -case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { def resolver: Resolver = conf.resolver @@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { val potentialSpecs = staticPartitions.filter { case (partKey, partValue) => resolver(field.name, partKey) } - if (potentialSpecs.size == 0) { + if (potentialSpecs.isEmpty) { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - Some(Alias(Cast(Literal(partValue), field.dataType), field.name)()) + Some(Alias(cast(Literal(partValue), field.dataType), field.name)()) } else { throw new AnalysisException( s"Partition column ${field.name} have multiple values specified, " + @@ -258,7 +258,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] /** * A Strategy for planning scans over data sources defined using the sources API. 
*/ -object DataSourceStrategy extends Strategy with Logging { +case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport { + import DataSourceStrategy._ + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( @@ -298,7 +300,7 @@ object DataSourceStrategy extends Strategy with Logging { // Restriction: Bucket pruning works iff the bucketing column has one and only one column. def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) val bucketIdGeneration = UnsafeProjection.create( HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, bucketColumn :: Nil) @@ -436,7 +438,9 @@ object DataSourceStrategy extends Strategy with Logging { private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = { toCatalystRDD(relation, relation.output, rdd) } +} +object DataSourceStrategy { /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * @@ -527,8 +531,8 @@ object DataSourceStrategy extends Strategy with Logging { * all [[Filter]]s that are completely filtered at the DataSource. */ protected[sql] def selectFilters( - relation: BaseRelation, - predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { + relation: BaseRelation, + predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are // called `predicate`s, while all data source filters of type `sources.Filter` are simply called diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7abf2ae5166b..3f4a78580f1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -315,7 +315,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. 
*/ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -367,7 +367,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { // Renaming is needed for handling the following cases like // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 // 2) Target tables have column metadata - Alias(Cast(actual, expected.dataType), expected.name)( + Alias(cast(actual, expected.dataType), expected.name)( explicitMetadata = Option(expected.metadata)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2b14eca919fa..df7c3678b780 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql index 364c022d959d..868a911e787f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/having.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -13,3 +13,6 @@ SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2; -- SPARK-11032: resolve having correctly SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); + +-- SPARK-20329: make sure we handle timezones correctly +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1; diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out index e0923832673c..d87ee5221647 100644 --- a/sql/core/src/test/resources/sql-tests/results/having.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 5 -- !query 0 @@ -38,3 +38,12 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0) struct -- !query 3 output 1 + + +-- !query 4 +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1 +-- !query 4 schema +struct<(a + CAST(b AS BIGINT)):bigint> +-- !query 4 output +3 +7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9b65419dba23..ba0ca666b5c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -90,6 +90,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: 
DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column @@ -98,7 +99,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) val matchedBuckets = new BitSet(numBuckets) bucketValues.foreach { value => - matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) } // Filter could hide the bug in bucket pruning. Thus, skipping all the filters diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index b16c9f8fc96b..735e07c21373 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal} import org.apache.spark.sql.execution.datasources.DataSourceAnalysis import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -49,7 +49,11 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { } Seq(true, false).foreach { caseSensitive => - val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + def cast(e: Expression, dt: DataType): Expression = { + Cast(e, dt, Option(conf.sessionLocalTimeZone)) + } + val rule = DataSourceAnalysis(conf) test( s"convertStaticPartitions only handle INSERT having at least static partitions " + s"(caseSensitive: $caseSensitive)") { @@ -150,7 +154,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { if (!caseSensitive) { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")), @@ -162,7 +166,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")), @@ -174,7 +178,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { // Test the case having a single static partition column. 
{ val nonPartitionedAttributes = Seq('e.int, 'f.int) - val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType)) + val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1")), @@ -189,7 +193,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { val dynamicPartitionAttributes = Seq('g.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType)) ++ + Seq(cast(Literal("1"), IntegerType)) ++ dynamicPartitionAttributes val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 9d3b31f39c0f..e16c9e46b772 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -101,7 +101,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ Seq( FileSourceStrategy, - DataSourceStrategy, + DataSourceStrategy(conf), SpecialLimits, InMemoryScans, HiveTableScans, From 48d760d028dd73371f99d084c4195dbc4dda5267 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 20 Apr 2017 19:40:21 -0700 Subject: [PATCH 324/512] [SPARK-20281][SQL] Print the identical Range parameters of SparkContext APIs and SQL in explain ## What changes were proposed in this pull request? This pr modified code to print the identical `Range` parameters of SparkContext APIs and SQL in `explain` output. In the current master, they internally use `defaultParallelism` for `splits` by default though, they print different strings in explain output; ``` scala> spark.range(4).explain == Physical Plan == *Range (0, 4, step=1, splits=Some(8)) scala> sql("select * from range(4)").explain == Physical Plan == *Range (0, 4, step=1, splits=None) ``` ## How was this patch tested? Added tests in `SQLQuerySuite` and modified some results in the existing tests. Author: Takeshi Yamamuro Closes #17670 from maropu/SPARK-20281. 
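For reference, a sketch of the explain output expected after this change, assuming the same `defaultParallelism` of 8 as in the example above; both entry points should now print the resolved number of slices:
```
scala> spark.range(4).explain
== Physical Plan ==
*Range (0, 4, step=1, splits=8)

scala> sql("select * from range(4)").explain
== Physical Plan ==
*Range (0, 4, step=1, splits=8)
```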
--- .../apache/spark/sql/execution/basicPhysicalOperators.scala | 3 ++- .../sql-tests/results/sql-compatibility-functions.sql.out | 2 +- .../resources/sql-tests/results/table-valued-functions.sql.out | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 233a105f4d93..d3efa428a6db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -332,6 +332,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) extends LeafExecNode with CodegenSupport { def start: Long = range.start + def end: Long = range.end def step: Long = range.step def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) def numElements: BigInt = range.numElements @@ -538,7 +539,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } } - override def simpleString: String = range.simpleString + override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)" } /** diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 9f0b95994be5..732b11050f46 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -88,7 +88,7 @@ Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nul == Physical Plan == *Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- *Range (0, 2, step=1, splits=None) ++- *Range (0, 2, step=1, splits=2) -- !query 9 diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index acd4ecf14617..e2ee970d35f6 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -102,4 +102,4 @@ EXPLAIN select * from RaNgE(2) struct -- !query 8 output == Physical Plan == -*Range (0, 2, step=1, splits=None) +*Range (0, 2, step=1, splits=2) From e2b3d2367a563d4600d8d87b5317e71135c362f0 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 21 Apr 2017 00:05:03 -0700 Subject: [PATCH 325/512] [SPARK-20420][SQL] Add events to the external catalog ## What changes were proposed in this pull request? It is often useful to be able to track changes to the `ExternalCatalog`. This PR makes the `ExternalCatalog` emit events when a catalog object is changed. Events are fired before and after the change. The following events are fired per object: - Database - CreateDatabasePreEvent: event fired before the database is created. - CreateDatabaseEvent: event fired after the database has been created. - DropDatabasePreEvent: event fired before the database is dropped. - DropDatabaseEvent: event fired after the database has been dropped. - Table - CreateTablePreEvent: event fired before the table is created. - CreateTableEvent: event fired after the table has been created. 
- RenameTablePreEvent: event fired before the table is renamed. - RenameTableEvent: event fired after the table has been renamed. - DropTablePreEvent: event fired before the table is dropped. - DropTableEvent: event fired after the table has been dropped. - Function - CreateFunctionPreEvent: event fired before the function is created. - CreateFunctionEvent: event fired after the function has been created. - RenameFunctionPreEvent: event fired before the function is renamed. - RenameFunctionEvent: event fired after the function has been renamed. - DropFunctionPreEvent: event fired before the function is dropped. - DropFunctionPreEvent: event fired after the function has been dropped. The current events currently only contain the names of the object modified. We add more events, and more details at a later point. A user can monitor changes to the external catalog by adding a listener to the Spark listener bus checking for `ExternalCatalogEvent`s using the `SparkListener.onOtherEvent` hook. A more direct approach is add listener directly to the `ExternalCatalog`. ## How was this patch tested? Added the `ExternalCatalogEventSuite`. Author: Herman van Hovell Closes #17710 from hvanhovell/SPARK-20420. --- .../catalyst/catalog/ExternalCatalog.scala | 85 +++++++- .../catalyst/catalog/InMemoryCatalog.scala | 22 +- .../spark/sql/catalyst/catalog/events.scala | 158 +++++++++++++++ .../catalog/ExternalCatalogEventSuite.scala | 188 ++++++++++++++++++ .../spark/sql/internal/SharedState.scala | 7 + .../spark/sql/hive/HiveExternalCatalog.scala | 22 +- 6 files changed, 457 insertions(+), 25 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 08a01e860189..974ef900e2ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -30,7 +31,8 @@ import org.apache.spark.sql.types.StructType * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. 
*/ -abstract class ExternalCatalog { +abstract class ExternalCatalog + extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { import CatalogTypes.TablePartitionSpec protected def requireDbExists(db: String): Unit = { @@ -61,9 +63,22 @@ abstract class ExternalCatalog { // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + doCreateDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + + final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + doDropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } - def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -88,11 +103,39 @@ abstract class ExternalCatalog { // Tables // -------------------------------------------------------------------------- - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit + final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(CreateTablePreEvent(db, name)) + doCreateTable(tableDefinition, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } - def dropTable(db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit + protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - def renameTable(db: String, oldName: String, newName: String): Unit + final def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + doDropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + protected def doDropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit + + final def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + doRenameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + protected def doRenameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -269,11 +312,30 @@ abstract class ExternalCatalog { // Functions // -------------------------------------------------------------------------- - def createFunction(db: String, funcDefinition: CatalogFunction): Unit + final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + doCreateFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } - def dropFunction(db: String, funcName: String): Unit + protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit - def renameFunction(db: String, oldName: String, newName: String): Unit + final 
def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + doDropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + protected def doDropFunction(db: String, funcName: String): Unit + + final def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + doRenameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + protected def doRenameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction @@ -281,4 +343,9 @@ abstract class ExternalCatalog { def listFunctions(db: String, pattern: String): Seq[String] + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9ca1c71d1dcb..81dd8efc0015 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,10 @@ class InMemoryCatalog( } } - override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireTableExists(db, oldName) requireTableNotExists(db, newName) val oldDesc = catalog(db).tables(oldName) @@ -565,18 +568,21 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { + override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override def dropFunction(db: String, funcName: String): Unit = synchronized { + override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { + override protected 
def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala new file mode 100644 index 000000000000..459973a13bb1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.scheduler.SparkListenerEvent + +/** + * Event emitted by the external catalog when it is modified. Events are either fired before or + * after the modification (the event should document this). + */ +trait ExternalCatalogEvent extends SparkListenerEvent + +/** + * Listener interface for external catalog modification events. + */ +trait ExternalCatalogEventListener { + def onEvent(event: ExternalCatalogEvent): Unit +} + +/** + * Event fired when a database is create or dropped. + */ +trait DatabaseEvent extends ExternalCatalogEvent { + /** + * Database of the object that was touched. + */ + val database: String +} + +/** + * Event fired before a database is created. + */ +case class CreateDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been created. + */ +case class CreateDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired before a database is dropped. + */ +case class DropDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been dropped. + */ +case class DropDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired when a table is created, dropped or renamed. + */ +trait TableEvent extends DatabaseEvent { + /** + * Name of the table that was touched. + */ + val name: String +} + +/** + * Event fired before a table is created. + */ +case class CreateTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been created. + */ +case class CreateTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is dropped. + */ +case class DropTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been dropped. + */ +case class DropTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is renamed. 
+ */ +case class RenameTablePreEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired after a table has been renamed. + */ +case class RenameTableEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired when a function is created, dropped or renamed. + */ +trait FunctionEvent extends DatabaseEvent { + /** + * Name of the function that was touched. + */ + val name: String +} + +/** + * Event fired before a function is created. + */ +case class CreateFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been created. + */ +case class CreateFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is dropped. + */ +case class DropFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been dropped. + */ +case class DropFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is renamed. + */ +case class RenameFunctionPreEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent + +/** + * Event fired after a function has been renamed. + */ +case class RenameFunctionEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala new file mode 100644 index 000000000000..2539ea615ff9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.catalog + +import java.net.URI +import java.nio.file.{Files, Path} + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.types.StructType + +/** + * Test Suite for external catalog events + */ +class ExternalCatalogEventSuite extends SparkFunSuite { + + protected def newCatalog: ExternalCatalog = new InMemoryCatalog() + + private def testWithCatalog( + name: String)( + f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { + val catalog = newCatalog + val recorder = mutable.Buffer.empty[ExternalCatalogEvent] + catalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + recorder += event + } + }) + f(catalog, (expected: Seq[ExternalCatalogEvent]) => { + val actual = recorder.clone() + recorder.clear() + assert(expected === actual) + }) + } + + private def createDbDefinition(uri: URI): CatalogDatabase = { + CatalogDatabase(name = "db5", description = "", locationUri = uri, Map.empty) + } + + private def createDbDefinition(): CatalogDatabase = { + createDbDefinition(preparePath(Files.createTempDirectory("db_"))) + } + + private def preparePath(path: Path): URI = path.normalize().toUri + + testWithCatalog("database") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createDatabase(dbDefinition, ignoreIfExists = true) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + intercept[AnalysisException] { + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + } + checkEvents(CreateDatabasePreEvent("db5") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropDatabase("db4", ignoreIfNotExists = false, cascade = false) + } + checkEvents(DropDatabasePreEvent("db4") :: Nil) + + catalog.dropDatabase("db5", ignoreIfNotExists = false, cascade = false) + checkEvents(DropDatabasePreEvent("db5") :: DropDatabaseEvent("db5") :: Nil) + + catalog.dropDatabase("db4", ignoreIfNotExists = true, cascade = false) + checkEvents(DropDatabasePreEvent("db4") :: DropDatabaseEvent("db4") :: Nil) + } + + testWithCatalog("table") { (catalog, checkEvents) => + val path1 = Files.createTempDirectory("db_") + val path2 = Files.createTempDirectory(path1, "tbl_") + val uri1 = preparePath(path1) + val uri2 = preparePath(path2) + + // CREATE + val dbDefinition = createDbDefinition(uri1) + + val storage = CatalogStorageFormat.empty.copy( + locationUri = Option(uri2)) + val tableDefinition = CatalogTable( + identifier = TableIdentifier("tbl1", Some("db5")), + tableType = CatalogTableType.MANAGED, + storage = storage, + schema = new StructType().add("id", "long")) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = false) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = true) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + intercept[AnalysisException] { + catalog.createTable(tableDefinition, 
ignoreIfExists = false) + } + checkEvents(CreateTablePreEvent("db5", "tbl1") :: Nil) + + // RENAME + catalog.renameTable("db5", "tbl1", "tbl2") + checkEvents( + RenameTablePreEvent("db5", "tbl1", "tbl2") :: + RenameTableEvent("db5", "tbl1", "tbl2") :: Nil) + + intercept[AnalysisException] { + catalog.renameTable("db5", "tbl1", "tbl2") + } + checkEvents(RenameTablePreEvent("db5", "tbl1", "tbl2") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropTable("db5", "tbl1", ignoreIfNotExists = false, purge = true) + } + checkEvents(DropTablePreEvent("db5", "tbl1") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = false, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = true, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + } + + testWithCatalog("function") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + val functionDefinition = CatalogFunction( + identifier = FunctionIdentifier("fn7", Some("db5")), + className = "", + resources = Seq.empty) + + val newIdentifier = functionDefinition.identifier.copy(funcName = "fn4") + val renamedFunctionDefinition = functionDefinition.copy(identifier = newIdentifier) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createFunction("db5", functionDefinition) + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: CreateFunctionEvent("db5", "fn7") :: Nil) + + intercept[AnalysisException] { + catalog.createFunction("db5", functionDefinition) + } + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: Nil) + + // RENAME + catalog.renameFunction("db5", "fn7", "fn4") + checkEvents( + RenameFunctionPreEvent("db5", "fn7", "fn4") :: + RenameFunctionEvent("db5", "fn7", "fn4") :: Nil) + intercept[AnalysisException] { + catalog.renameFunction("db5", "fn7", "fn4") + } + checkEvents(RenameFunctionPreEvent("db5", "fn7", "fn4") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropFunction("db5", "fn7") + } + checkEvents(DropFunctionPreEvent("db5", "fn7") :: Nil) + + catalog.dropFunction("db5", "fn4") + checkEvents(DropFunctionPreEvent("db5", "fn4") :: DropFunctionEvent("db5", "fn4") :: Nil) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index d06dbaa2d0ab..f834569e59b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -109,6 +109,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } + // Make sure we propagate external catalog events to the spark listener bus + externalCatalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + sparkContext.listenerBus.post(event) + } + }) + /** * A manager for global temporary views. 
*/ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 8b0fdf49cefa..71e33c46b9ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -141,13 +141,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -194,7 +194,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -456,7 +456,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -465,7 +465,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = withClient { val rawTable = getRawTable(db, oldName) // Note that Hive serde tables don't use path option in storage properties to store the value @@ -1056,7 +1059,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override def createFunction( + override protected def doCreateFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1069,12 +1072,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override def dropFunction(db: String, name: String): Unit = withClient { + override protected def doDropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = withClient { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) client.renameFunction(db, oldName, newName) From 34767997e0c6cb28e1fac8cb650fa3511f260ca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9?= Date: Fri, 21 Apr 2017 08:52:18 +0100 Subject: [PATCH 326/512] Small rewording about history server use case MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hello PR #10991 removed the built-in history view from Spark Standalone, so the history server is no longer useful to Yarn or 
Mesos only. Author: Hervé Closes #17709 from dud225/patch-1. --- docs/monitoring.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index da954385dc45..3e577c5f3677 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -27,8 +27,8 @@ in the UI to persisted storage. ## Viewing After the Fact -If Spark is run on Mesos or YARN, it is still possible to construct the UI of an -application through Spark's history server, provided that the application's event logs exist. +It is still possible to construct the UI of an application through Spark's history server, +provided that the application's event logs exist. You can start the history server by executing: ./sbin/start-history-server.sh From c9e6035e1fb825d280eaec3bdfc1e4d362897ffd Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 21 Apr 2017 22:11:24 +0800 Subject: [PATCH 327/512] [SPARK-20412] Throw ParseException from visitNonOptionalPartitionSpec instead of returning null values. ## What changes were proposed in this pull request? If a partitionSpec is supposed to not contain optional values, a ParseException should be thrown, and not nulls returned. The nulls can later cause NullPointerExceptions in places not expecting them. ## How was this patch tested? A query like "SHOW PARTITIONS tbl PARTITION(col1='val1', col2)" used to throw a NullPointerException. Now it throws a ParseException. Author: Juliusz Sompolski Closes #17707 from juliuszsompolski/SPARK-20412. --- .../spark/sql/catalyst/parser/AstBuilder.scala | 5 ++++- .../sql/execution/command/DDLCommandSuite.scala | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e1db1ef5b869..2cf06d15664d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -215,7 +215,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ protected def visitNonOptionalPartitionSpec( ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { - visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 97c61dc8694b..8a6bc62fec96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -530,13 +530,13 @@ class DDLCommandSuite extends PlanTest { """.stripMargin val sql4 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDE 'org.apache.class' WITH SERDEPROPERTIES ('columns'='foo,bar', |'field.delim' = ',') """.stripMargin val sql5 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed1 = parser.parsePlan(sql1) @@ -558,12 
+558,12 @@ class DDLCommandSuite extends PlanTest { tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) val expected5 = AlterTableSerDePropertiesCommand( tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) @@ -832,6 +832,14 @@ class DDLCommandSuite extends PlanTest { assert(e.contains("Found duplicate keys 'a'")) } + test("empty values in non-optional partition specs") { + val e = intercept[ParseException] { + parser.parsePlan( + "SHOW PARTITIONS dbx.tab1 PARTITION (a='1', b)") + }.getMessage + assert(e.contains("Found an empty partition key 'b'")) + } + test("drop table") { val tableName1 = "db.tab" val tableName2 = "tab" From a750a595976791cb8a77063f690ea8f82ea75a8f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 21 Apr 2017 22:25:35 +0800 Subject: [PATCH 328/512] [SPARK-20341][SQL] Support BigInt's value that does not fit in long value range ## What changes were proposed in this pull request? This PR avoids an exception in the case where `scala.math.BigInt` has a value that does not fit into long value range (e.g. `Long.MAX_VALUE+1`). When we run the following code by using the current Spark, the following exception is thrown. This PR keeps the value using `BigDecimal` if we detect such an overflow case by catching `ArithmeticException`. Sample program: ``` case class BigIntWrapper(value:scala.math.BigInt)``` spark.createDataset(BigIntWrapper(scala.math.BigInt("10000000000000000002"))::Nil).show ``` Exception: ``` Error while encoding: java.lang.ArithmeticException: BigInteger out of long range staticinvoke(class org.apache.spark.sql.types.Decimal$, DecimalType(38,0), apply, assertnotnull(assertnotnull(input[0, org.apache.spark.sql.BigIntWrapper, true])).value, true) AS value#0 java.lang.RuntimeException: Error while encoding: java.lang.ArithmeticException: BigInteger out of long range staticinvoke(class org.apache.spark.sql.types.Decimal$, DecimalType(38,0), apply, assertnotnull(assertnotnull(input[0, org.apache.spark.sql.BigIntWrapper, true])).value, true) AS value#0 at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.toRow(ExpressionEncoder.scala:290) at org.apache.spark.sql.SparkSession$$anonfun$2.apply(SparkSession.scala:454) at org.apache.spark.sql.SparkSession$$anonfun$2.apply(SparkSession.scala:454) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.immutable.List.foreach(List.scala:381) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.immutable.List.map(List.scala:285) at org.apache.spark.sql.SparkSession.createDataset(SparkSession.scala:454) at org.apache.spark.sql.Agg$$anonfun$18.apply$mcV$sp(MySuite.scala:192) at org.apache.spark.sql.Agg$$anonfun$18.apply(MySuite.scala:192) at org.apache.spark.sql.Agg$$anonfun$18.apply(MySuite.scala:192) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at 
org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) ... Caused by: java.lang.ArithmeticException: BigInteger out of long range at java.math.BigInteger.longValueExact(BigInteger.java:4531) at org.apache.spark.sql.types.Decimal.set(Decimal.scala:140) at org.apache.spark.sql.types.Decimal$.apply(Decimal.scala:434) at org.apache.spark.sql.types.Decimal.apply(Decimal.scala) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source) at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.toRow(ExpressionEncoder.scala:287) ... 59 more ``` ## How was this patch tested? Add new test suite into `DecimalSuite` Author: Kazuaki Ishizaki Closes #17684 from kiszk/SPARK-20341. --- .../org/apache/spark/sql/types/Decimal.scala | 20 +++++++++++++------ .../apache/spark/sql/types/DecimalSuite.scala | 6 ++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index e8f6884c025c..80916ee9c537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -132,14 +132,22 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** - * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. + * If the value is not in the range of long, convert it to BigDecimal and + * the precision and scale are based on the converted value. 
+ * + * This code avoids BigDecimal object allocation as possible to improve runtime efficiency */ def set(bigintval: BigInteger): Decimal = { - this.decimalVal = null - this.longVal = bigintval.longValueExact() - this._precision = DecimalType.MAX_PRECISION - this._scale = 0 - this + try { + this.decimalVal = null + this.longVal = bigintval.longValueExact() + this._precision = DecimalType.MAX_PRECISION + this._scale = 0 + this + } catch { + case _: ArithmeticException => + set(BigDecimal(bigintval)) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 714883a4099c..93c231e30b49 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -212,4 +212,10 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { } } } + + test("SPARK-20341: support BigInt's value does not fit in long value range") { + val bigInt = scala.math.BigInt("9223372036854775808") + val decimal = Decimal.apply(bigInt) + assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808") + } } From eb00378f0eed6afbf328ae6cd541cc202d14c1f0 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 21 Apr 2017 17:58:13 +0000 Subject: [PATCH 329/512] [SPARK-20423][ML] fix MLOR coeffs centering when reg == 0 ## What changes were proposed in this pull request? When reg == 0, MLOR has multiple solutions and we need to centralize the coeffs to get identical result. BUT current implementation centralize the `coefficientMatrix` by the global coeffs means. In fact the `coefficientMatrix` should be centralized on each feature index itself. Because, according to the MLOR probability distribution function, it can be proven easily that: suppose `{ w0, w1, .. w(K-1) }` make up the `coefficientMatrix`, then `{ w0 + c, w1 + c, ... w(K - 1) + c}` will also be the equivalent solution. `c` is an arbitrary vector of `numFeatures` dimension. reference https://core.ac.uk/download/pdf/6287975.pdf So that we need to centralize the `coefficientMatrix` on each feature dimension separately. **We can also confirm this through R library `glmnet`, that MLOR in `glmnet` always generate coefficients result that the sum of each dimension is all `zero`, when reg == 0.** ## How was this patch tested? Tests added. Author: WeichenXu Closes #17706 from WeichenXu123/mlor_center. --- .../spark/ml/classification/LogisticRegression.scala | 11 ++++++++--- .../ml/classification/LogisticRegressionSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 965ce3d6f275..bc8154692e52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -609,9 +609,14 @@ class LogisticRegression @Since("1.2.0") ( Friedman, et al. 
"Regularization Paths for Generalized Linear Models via Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf */ - val denseValues = denseCoefficientMatrix.values - val coefficientMean = denseValues.sum / denseValues.length - denseCoefficientMatrix.update(_ - coefficientMean) + val centers = Array.fill(numFeatures)(0.0) + denseCoefficientMatrix.foreachActive { case (i, j, v) => + centers(j) += v + } + centers.transform(_ / numCoefficientSets) + denseCoefficientMatrix.foreachActive { case (i, j, v) => + denseCoefficientMatrix.update(i, j, v - centers(j)) + } } // center the intercepts when using multinomial algorithm diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index c858b9bbfc25..83f575e83828 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1139,6 +1139,9 @@ class LogisticRegressionSuite 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) val interceptsR = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) + model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) assert(model1.interceptVector ~== interceptsR relTol 0.05) @@ -1204,6 +1207,9 @@ class LogisticRegressionSuite -0.3180040, 0.9679074, -0.2252219, -0.4319914, 0.2452411, -0.6046524, 0.1050710, 0.1180180), isTransposed = true) + model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) From fd648bff63f91a30810910dfc5664eea0ff5e6f9 Mon Sep 17 00:00:00 2001 From: zero323 Date: Fri, 21 Apr 2017 12:06:21 -0700 Subject: [PATCH 330/512] [SPARK-20371][R] Add wrappers for collect_list and collect_set ## What changes were proposed in this pull request? Adds wrappers for `collect_list` and `collect_set`. ## How was this patch tested? Unit tests, `check-cran.sh` Author: zero323 Closes #17672 from zero323/SPARK-20371. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 40 +++++++++++++++++++++++ R/pkg/R/generics.R | 9 +++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 22 +++++++++++++ 4 files changed, 73 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index b6b559adf06e..e804e30e14b8 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -203,6 +203,8 @@ exportMethods("%in%", "cbrt", "ceil", "ceiling", + "collect_list", + "collect_set", "column", "concat", "concat_ws", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f854df11e576..e7decb91867b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3705,3 +3705,43 @@ setMethod("create_map", jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols) column(jc) }) + +#' collect_list +#' +#' Creates a list of objects with duplicates. 
+#' +#' @param x Column to compute on +#' +#' @rdname collect_list +#' @name collect_list +#' @family agg_funcs +#' @aliases collect_list,Column-method +#' @export +#' @examples \dontrun{collect_list(df$x)} +#' @note collect_list since 2.3.0 +setMethod("collect_list", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_list", x@jc) + column(jc) + }) + +#' collect_set +#' +#' Creates a list of objects with duplicate elements eliminated. +#' +#' @param x Column to compute on +#' +#' @rdname collect_set +#' @name collect_set +#' @family agg_funcs +#' @aliases collect_set,Column-method +#' @export +#' @examples \dontrun{collect_set(df$x)} +#' @note collect_set since 2.3.0 +setMethod("collect_set", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index da46823f52a1..61d248ebd2e3 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -918,6 +918,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) +#' @rdname collect_list +#' @export +setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) + +#' @rdname collect_set +#' @export +setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) + #' @rdname column #' @export setGeneric("column", function(x) { standardGeneric("column") }) @@ -1358,6 +1366,7 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) + ###################### Spark.ML Methods ########################## #' @rdname fitted diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9e87a4710699..bf2093fdc475 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1731,6 +1731,28 @@ test_that("group by, agg functions", { expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + # Test collect_list and collect_set + gd3_collections_local <- collect( + agg(gd3, collect_set(df8$age), collect_list(df8$age)) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 2]), + c(30) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 3]), + c(30, 30) + ) + + expect_equal( + sort(unlist( + gd3_collections_local[gd3_collections_local$name == "Justin", 3] + )), + c(1, 19) + ) + unlink(jsonPath2) unlink(jsonPath3) }) From ad290402aa1d609abf5a2883a6d87fa8bc2bd517 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Fri, 21 Apr 2017 20:08:26 +0100 Subject: [PATCH 331/512] [SPARK-20401][DOC] In the spark official configuration document, the 'spark.driver.supervise' configuration parameter specification and default values are necessary. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Use the REST interface submits the spark job. e.g. 
curl -X POST http://10.43.183.120:6066/v1/submissions/create --header "Content-Type:application/json;charset=UTF-8" --data'{ "action": "CreateSubmissionRequest", "appArgs": [ "myAppArgument" ], "appResource": "/home/mr/gxl/test.jar", "clientSparkVersion": "2.2.0", "environmentVariables": { "SPARK_ENV_LOADED": "1" }, "mainClass": "cn.zte.HdfsTest", "sparkProperties": { "spark.jars": "/home/mr/gxl/test.jar", **"spark.driver.supervise": "true",** "spark.app.name": "HdfsTest", "spark.eventLog.enabled": "false", "spark.submit.deployMode": "cluster", "spark.master": "spark://10.43.183.120:6066" } }' **I want to make sure that the driver is automatically restarted if it fails with a non-zero exit code, but I cannot find the specification and default value of the 'spark.driver.supervise' configuration parameter in the official Spark configuration documentation.** ## How was this patch tested? Manual tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Author: guoxiaolong Author: guoxiaolongzte Closes #17696 from guoxiaolongzte/SPARK-20401. --- docs/configuration.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 2687f542b8bd..6b65d2bcb83e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -213,6 +213,14 @@ of the most common options to set are: and typically can have up to 50 characters. + + spark.driver.supervise + false + + If true, restarts the driver automatically if it fails with a non-zero exit status. + Only has effect in Spark standalone mode or Mesos cluster deploy mode. + + Apart from these, the following properties are also available, and may be useful in some situations: From 05a451491d535c0828413ce2eb06fe94571069ac Mon Sep 17 00:00:00 2001 From: eatoncys Date: Sat, 22 Apr 2017 12:29:35 +0100 Subject: [PATCH 332/512] [SPARK-20386][SPARK CORE] modify the log info if the block exists on the slave already ## What changes were proposed in this pull request? Modify the added memory size to memSize - originalMemSize if the block already exists on the slave, since in that case the amount of memory actually added is memSize - originalMemSize; if originalMemSize is larger than memSize, the log should instead report removed memory with a size of originalMemSize - memSize. ## How was this patch tested? Multiple runs of the existing unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: eatoncys Closes #17683 from eatoncys/SPARK-20386. 
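To illustrate the intended difference, a sketch of the log lines for a block that is re-registered with a new size, following the message templates in this patch (the block id, host and sizes are made-up values):
```
// Before: a re-registered block was reported as newly added with its full size
Added rdd_0_1 in memory on 192.168.0.2:43211 (size: 100.0 MB, free: 300.0 MB)

// After: the update is reported together with the original size
Updated rdd_0_1 in memory on 192.168.0.2:43211 (current size: 100.0 MB, original size: 80.0 MB, free: 300.0 MB)
```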
--- .../storage/BlockManagerMasterEndpoint.scala | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 467c3e0e6b51..6f85b9e4d6c7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -497,11 +497,17 @@ private[spark] class BlockManagerInfo( updateLastSeenMs() - if (_blocks.containsKey(blockId)) { + val blockExists = _blocks.containsKey(blockId) + var originalMemSize: Long = 0 + var originalDiskSize: Long = 0 + var originalLevel: StorageLevel = StorageLevel.NONE + + if (blockExists) { // The block exists on the slave already. val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize + originalLevel = blockStatus.storageLevel + originalMemSize = blockStatus.memSize + originalDiskSize = blockStatus.diskSize if (originalLevel.useMemory) { _remainingMem += originalMemSize @@ -520,32 +526,44 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0) _blocks.put(blockId, blockStatus) _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) + if (blockExists) { + logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(memSize)}," + + s" original size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } else { + logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(memSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } } if (storageLevel.useDisk) { blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + if (blockExists) { + logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(diskSize)}," + + s" original size: ${Utils.bytesToString(originalDiskSize)})") + } else { + logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(diskSize)})") + } } if (!blockId.isBroadcast && blockStatus.isCached) { _cachedBlocks += blockId } - } else if (_blocks.containsKey(blockId)) { + } else if (blockExists) { // If isValid is not true, drop the block. 
- val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) _cachedBlocks -= blockId - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) + if (originalLevel.useMemory) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" + + s" (size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + if (originalLevel.useDisk) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" + + s" (size: ${Utils.bytesToString(originalDiskSize)})") } } } From b3c572a6b332b79fef72c309b9038b3c939dcba2 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 22 Apr 2017 09:41:58 -0700 Subject: [PATCH 333/512] [SPARK-20430][SQL] Initialise RangeExec parameters in a driver side ## What changes were proposed in this pull request? This pr initialised `RangeExec` parameters in a driver side. In the current master, a query below throws `NullPointerException`; ``` sql("SET spark.sql.codegen.wholeStage=false") sql("SELECT * FROM range(1)").show 17/04/20 17:11:05 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0) java.lang.NullPointerException at org.apache.spark.sql.execution.SparkPlan.sparkContext(SparkPlan.scala:54) at org.apache.spark.sql.execution.RangeExec.numSlices(basicPhysicalOperators.scala:343) at org.apache.spark.sql.execution.RangeExec$$anonfun$20.apply(basicPhysicalOperators.scala:506) at org.apache.spark.sql.execution.RangeExec$$anonfun$20.apply(basicPhysicalOperators.scala:505) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1$$anonfun$apply$26.apply(RDD.scala:844) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1$$anonfun$apply$26.apply(RDD.scala:844) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323) at org.apache.spark.rdd.RDD.iterator(RDD.scala:287) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323) at org.apache.spark.rdd.RDD.iterator(RDD.scala:287) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:108) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:320) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) ``` ## How was this patch tested? Added a test in `DataFrameRangeSuite`. Author: Takeshi Yamamuro Closes #17717 from maropu/SPARK-20430. 
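With the parameters initialised on the driver, the failing query above should now run; a rough sketch of the expected behaviour (the `id` column name comes from the usual `range` schema):
```
scala> sql("SET spark.sql.codegen.wholeStage=false")
scala> sql("SELECT * FROM range(1)").show
+---+
| id|
+---+
|  0|
+---+
```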
--- .../spark/sql/execution/basicPhysicalOperators.scala | 10 +++++----- .../org/apache/spark/sql/DataFrameRangeSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d3efa428a6db..64698d552757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -331,11 +331,11 @@ case class SampleExec( case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) extends LeafExecNode with CodegenSupport { - def start: Long = range.start - def end: Long = range.end - def step: Long = range.step - def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) - def numElements: BigInt = range.numElements + val start: Long = range.start + val end: Long = range.end + val step: Long = range.step + val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) + val numElements: BigInt = range.numElements override val output: Seq[Attribute] = range.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 5e323c02b253..7b495656b93d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -185,6 +185,12 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + + test("SPARK-20430 Initialize Range parameters in a driver side") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) + } + } } object DataFrameRangeSuite { From 8765bc17d0439032d0378686c4f2b17df2432abc Mon Sep 17 00:00:00 2001 From: Michael Patterson Date: Sat, 22 Apr 2017 19:58:54 -0700 Subject: [PATCH 334/512] [SPARK-20132][DOCS] Add documentation for column string functions ## What changes were proposed in this pull request? Add docstrings to column.py for the Column functions `rlike`, `like`, `startswith`, and `endswith`. Pass these docstrings through `_bin_op` There may be a better place to put the docstrings. I put them immediately above the Column class. ## How was this patch tested? I ran `make html` on my local computer to remake the documentation, and verified that the html pages were displaying the docstrings correctly. I tried running `dev-tests`, and the formatting tests passed. However, my mvn build didn't work I think due to issues on my computer. These docstrings are my original work and free license. davies has done the most recent work reorganizing `_bin_op` Author: Michael Patterson Closes #17469 from map222/patterson-documentation. --- python/pyspark/sql/column.py | 70 ++++++++++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index ec05c18d4f06..46c1707cb6c3 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -250,11 +250,50 @@ def __iter__(self): raise TypeError("Column is not iterable") # string methods + _rlike_doc = """ + Return a Boolean :class:`Column` based on a regex match. 
+ + :param other: an extended regex expression + + >>> df.filter(df.name.rlike('ice$')).collect() + [Row(age=2, name=u'Alice')] + """ + _like_doc = """ + Return a Boolean :class:`Column` based on a SQL LIKE match. + + :param other: a SQL LIKE pattern + + See :func:`rlike` for a regex version + + >>> df.filter(df.name.like('Al%')).collect() + [Row(age=2, name=u'Alice')] + """ + _startswith_doc = """ + Return a Boolean :class:`Column` based on a string match. + + :param other: string at end of line (do not use a regex `^`) + + >>> df.filter(df.name.startswith('Al')).collect() + [Row(age=2, name=u'Alice')] + >>> df.filter(df.name.startswith('^Al')).collect() + [] + """ + _endswith_doc = """ + Return a Boolean :class:`Column` based on matching end of string. + + :param other: string at end of line (do not use a regex `$`) + + >>> df.filter(df.name.endswith('ice')).collect() + [Row(age=2, name=u'Alice')] + >>> df.filter(df.name.endswith('ice$')).collect() + [] + """ + contains = _bin_op("contains") - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") + rlike = ignore_unicode_prefix(_bin_op("rlike", _rlike_doc)) + like = ignore_unicode_prefix(_bin_op("like", _like_doc)) + startswith = ignore_unicode_prefix(_bin_op("startsWith", _startswith_doc)) + endswith = ignore_unicode_prefix(_bin_op("endsWith", _endswith_doc)) @ignore_unicode_prefix @since(1.3) @@ -303,8 +342,27 @@ def isin(self, *cols): desc = _unary_op("desc", "Returns a sort expression based on the" " descending order of the given column name.") - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + _isNull_doc = """ + True if the current expression is null. Often combined with + :func:`DataFrame.filter` to select rows with null values. + + >>> from pyspark.sql import Row + >>> df2 = sc.parallelize([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]).toDF() + >>> df2.filter(df2.height.isNull()).collect() + [Row(height=None, name=u'Alice')] + """ + _isNotNull_doc = """ + True if the current expression is null. Often combined with + :func:`DataFrame.filter` to select rows with non-null values. + + >>> from pyspark.sql import Row + >>> df2 = sc.parallelize([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]).toDF() + >>> df2.filter(df2.height.isNotNull()).collect() + [Row(height=80, name=u'Tom')] + """ + + isNull = ignore_unicode_prefix(_unary_op("isNull", _isNull_doc)) + isNotNull = ignore_unicode_prefix(_unary_op("isNotNull", _isNotNull_doc)) @since(1.3) def alias(self, *alias, **kwargs): From 2eaf4f3fe3595ae341a3a5ce886b859992dea5b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Sun, 23 Apr 2017 13:33:14 +0100 Subject: [PATCH 335/512] [SPARK-20385][WEB-UI] Submitted Time' field, the date format needs to be formatted, in running Drivers table or Completed Drivers table in master web ui. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Submitted Time' field, the date format **needs to be formatted**, in running Drivers table or Completed Drivers table in master web ui. Before fix this problem e.g. 
Completed Drivers Submission ID **Submitted Time** Worker State Cores Memory Main Class driver-20170419145755-0005 **Wed Apr 19 14:57:55 CST 2017** worker-20170419145250-zdh120-40412 FAILED 1 1024.0 MB cn.zte.HdfsTest please see the attachment:https://issues.apache.org/jira/secure/attachment/12863977/before_fix.png After fix this problem e.g. Completed Drivers Submission ID **Submitted Time** Worker State Cores Memory Main Class driver-20170419145755-0006 **2017/04/19 16:01:25** worker-20170419145250-zdh120-40412 FAILED 1 1024.0 MB cn.zte.HdfsTest please see the attachment:https://issues.apache.org/jira/secure/attachment/12863976/after_fix.png 'Submitted Time' field, the date format **has been formatted**, in running Applications table or Completed Applicationstable in master web ui, **it is correct.** e.g. Running Applications Application ID Name Cores Memory per Executor **Submitted Time** User State Duration app-20170419160910-0000 (kill) SparkSQL::10.43.183.120 1 5.0 GB **2017/04/19 16:09:10** root RUNNING 53 s **Format after the time easier to observe, and consistent with the applications table,so I think it's worth fixing.** ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Author: guoxiaolong Author: guoxiaolongzte Closes #17682 from guoxiaolongzte/SPARK-20385. --- .../apache/spark/deploy/master/ui/ApplicationPage.scala | 2 +- .../org/apache/spark/deploy/master/ui/MasterPage.scala | 2 +- .../org/apache/spark/deploy/mesos/ui/DriverPage.scala | 4 ++-- .../apache/spark/deploy/mesos/ui/MesosClusterPage.scala | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 946a92882141..a8d721f3e0d4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -83,7 +83,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") Executor Memory: {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
-            <li><strong>Submit Date:</strong> {app.submitDate}</li>
+            <li><strong>Submit Date:</strong> {UIUtils.formatDate(app.submitDate)}</li>
             <li><strong>State:</strong> {app.state}</li>
  • { if (!app.isFinished) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index e722a24d4a89..9351c72094e3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -252,7 +252,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } {driver.id} {killLink} - {driver.submitDate} + {UIUtils.formatDate(driver.submitDate)} {driver.worker.map(w => if (w.isAlive()) { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index cd98110ddcc0..127fadabcce5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -101,7 +101,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") Launch Time - {state.startDate} + {UIUtils.formatDate(state.startDate)} Finish Time @@ -154,7 +154,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") Memory{driver.mem} - Submitted{driver.submissionDate} + Submitted{UIUtils.formatDate(driver.submissionDate)} Supervise{driver.supervise} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 13ba7d311e57..c9107c3e73d3 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -68,7 +68,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val id = submission.submissionId {id} - {submission.submissionDate} + {UIUtils.formatDate(submission.submissionDate)} {submission.command.mainClass} cpus: {submission.cores}, mem: {submission.mem} @@ -88,10 +88,10 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {id} {historyCol} - {state.driverDescription.submissionDate} + {UIUtils.formatDate(state.driverDescription.submissionDate)} {state.driverDescription.command.mainClass} cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem} - {state.startDate} + {UIUtils.formatDate(state.startDate)} {state.slaveId.getValue} {stateString(state.mesosTaskStatus)} @@ -101,7 +101,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val id = submission.submissionId {id} - {submission.submissionDate} + {UIUtils.formatDate(submission.submissionDate)} {submission.command.mainClass} {submission.retryState.get.lastFailureStatus} {submission.retryState.get.nextRetry} From e9f97154bc4af60376a550238315d7fc57099f9c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 24 Apr 2017 09:34:38 +0100 Subject: [PATCH 336/512] [BUILD] Close stale PRs ## What changes were proposed in this pull request? This pr proposed to close stale PRs. Currently, we have 400+ open PRs and there are some stale PRs whose JIRA tickets have been already closed and whose JIRA tickets does not exist (also, they seem not to be minor issues). 
// Open PRs whose JIRA tickets have been already closed Closes #11785 Closes #13027 Closes #13614 Closes #13761 Closes #15197 Closes #14006 Closes #12576 Closes #15447 Closes #13259 Closes #15616 Closes #14473 Closes #16638 Closes #16146 Closes #17269 Closes #17313 Closes #17418 Closes #17485 Closes #17551 Closes #17463 Closes #17625 // Open PRs whose JIRA tickets does not exist and they are not minor issues Closes #10739 Closes #15193 Closes #15344 Closes #14804 Closes #16993 Closes #17040 Closes #15180 Closes #17238 ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #17734 from maropu/resolved_pr. From 776a2c0e91dfea170ea1c489118e1d42c4121f35 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 24 Apr 2017 17:21:42 +0800 Subject: [PATCH 337/512] [SPARK-20439][SQL] Fix Catalog API listTables and getTable when failed to fetch table metadata ### What changes were proposed in this pull request? `spark.catalog.listTables` and `spark.catalog.getTable` does not work if we are unable to retrieve table metadata due to any reason (e.g., table serde class is not accessible or the table type is not accepted by Spark SQL). After this PR, the APIs still return the corresponding Table without the description and tableType) ### How was this patch tested? Added a test case Author: Xiao Li Closes #17730 from gatorsmile/listTables. --- .../spark/sql/internal/CatalogImpl.scala | 28 +++++++++++++++---- .../sql/hive/execution/HiveDDLSuite.scala | 8 ++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index aebb663df5c9..0b8e53868c99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ @@ -98,14 +99,27 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(tables, sparkSession) } + /** + * Returns a Table for the given table/view or temporary view. + * + * Note that this function requires the table already exists in the Catalog. + * + * If the table metadata retrieval failed due to any reason (e.g., table serde class + * is not accessible or the table type is not accepted by Spark SQL), this function + * still returns the corresponding Table without the description and tableType) + */ private def makeTable(tableIdent: TableIdentifier): Table = { - val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val metadata = try { + Some(sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent)) + } catch { + case NonFatal(_) => None + } val isTemp = sessionCatalog.isTemporaryTable(tableIdent) new Table( name = tableIdent.table, - database = metadata.identifier.database.orNull, - description = metadata.comment.orNull, - tableType = if (isTemp) "TEMPORARY" else metadata.tableType.name, + database = metadata.map(_.identifier.database).getOrElse(tableIdent.database).orNull, + description = metadata.map(_.comment.orNull).orNull, + tableType = if (isTemp) "TEMPORARY" else metadata.map(_.tableType.name).orNull, isTemporary = isTemp) } @@ -197,7 +211,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * `AnalysisException` when no `Table` can be found. 
*/ override def getTable(dbName: String, tableName: String): Table = { - makeTable(TableIdentifier(tableName, Option(dbName))) + if (tableExists(dbName, tableName)) { + makeTable(TableIdentifier(tableName, Option(dbName))) + } else { + throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 3906968aaff1..16a99321bad3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1197,6 +1197,14 @@ class HiveDDLSuite s"CREATE INDEX $indexName ON TABLE $tabName (a) AS 'COMPACT' WITH DEFERRED REBUILD") val indexTabName = spark.sessionState.catalog.listTables("default", s"*$indexName*").head.table + + // Even if index tables exist, listTables and getTable APIs should still work + checkAnswer( + spark.catalog.listTables().toDF(), + Row(indexTabName, "default", null, null, false) :: + Row(tabName, "default", null, "MANAGED", false) :: Nil) + assert(spark.catalog.getTable("default", indexTabName).name === indexTabName) + intercept[TableAlreadyExistsException] { sql(s"CREATE TABLE $indexTabName(b int)") } From 90264aced7cfdf265636517b91e5d1324fe60112 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Mon, 24 Apr 2017 23:43:06 +0800 Subject: [PATCH 338/512] [SPARK-18901][ML] Require in LR LogisticAggregator is redundant ## What changes were proposed in this pull request? In MultivariateOnlineSummarizer, `add` and `merge` have check for weights and feature sizes. The checks in LR are redundant, which are removed from this PR. ## How was this patch tested? Existing tests. Author: wm624@hotmail.com Closes #17478 from wangmiao1981/logit. --- .../apache/spark/ml/classification/LogisticRegression.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index bc8154692e52..44b3478e0c3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1571,9 +1571,6 @@ private class LogisticAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $numFeatures but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1596,8 +1593,6 @@ private class LogisticAggregator( * @return This LogisticAggregator object. */ def merge(other: LogisticAggregator): this.type = { - require(numFeatures == other.numFeatures, s"Dimensions mismatch when merging with another " + - s"LogisticAggregator. Expecting $numFeatures but got ${other.numFeatures}.") if (other.weightSum != 0.0) { weightSum += other.weightSum From 8a272ddc9d2359a724aa89ae2f8de121a4aa7ac2 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 24 Apr 2017 10:56:57 -0700 Subject: [PATCH 339/512] [SPARK-20438][R] SparkR wrappers for split and repeat ## What changes were proposed in this pull request? 
Add wrappers for `o.a.s.sql.functions`: - `split` as `split_string` - `repeat` as `repeat_string` ## How was this patch tested? Existing tests, additional unit tests, `check-cran.sh` Author: zero323 Closes #17729 from zero323/SPARK-20438. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 58 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 34 +++++++++++++ 4 files changed, 102 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index e804e30e14b8..95d5cc6d1c78 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -300,6 +300,7 @@ exportMethods("%in%", "rank", "regexp_extract", "regexp_replace", + "repeat_string", "reverse", "rint", "rlike", @@ -323,6 +324,7 @@ exportMethods("%in%", "sort_array", "soundex", "spark_partition_id", + "split_string", "stddev", "stddev_pop", "stddev_samp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e7decb91867b..752e4c5c7189 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3745,3 +3745,61 @@ setMethod("collect_set", jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc) column(jc) }) + +#' split_string +#' +#' Splits string on regular expression. +#' +#' Equivalent to \code{split} SQL function +#' +#' @param x Column to compute on +#' @param pattern Java regular expression +#' +#' @rdname split_string +#' @family string_funcs +#' @aliases split_string,Column-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' head(select(df, split_string(df$value, "\\s+"))) +#' +#' # This is equivalent to the following SQL expression +#' head(selectExpr(df, "split(value, '\\\\s+')")) +#' } +#' @note split_string 2.3.0 +setMethod("split_string", + signature(x = "Column", pattern = "character"), + function(x, pattern) { + jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern) + column(jc) + }) + +#' repeat_string +#' +#' Repeats string n times. 
+#' +#' Equivalent to \code{repeat} SQL function +#' +#' @param x Column to compute on +#' @param n Number of repetitions +#' +#' @rdname repeat_string +#' @family string_funcs +#' @aliases repeat_string,Column-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' first(select(df, repeat_string(df$value, 3))) +#' +#' # This is equivalent to the following SQL expression +#' first(selectExpr(df, "repeat(value, 3)")) +#' } +#' @note repeat_string 2.3.0 +setMethod("repeat_string", + signature(x = "Column", n = "numeric"), + function(x, n) { + jc <- callJStatic("org.apache.spark.sql.functions", "repeat", x@jc, numToInt(n)) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 61d248ebd2e3..5e7a1c60c2b3 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1192,6 +1192,10 @@ setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) +#' @rdname repeat_string +#' @export +setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) + #' @rdname reverse #' @export setGeneric("reverse", function(x) { standardGeneric("reverse") }) @@ -1257,6 +1261,10 @@ setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @export setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) +#' @rdname split_string +#' @export +setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index bf2093fdc475..c21ba2f1a138 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1546,6 +1546,40 @@ test_that("string operators", { expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") + + l4 <- list(list(a = "a.b@c.d 1\\b")) + df4 <- createDataFrame(l4) + expect_equal( + collect(select(df4, split_string(df4$a, "\\s+")))[1, 1], + list(list("a.b@c.d", "1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\.")))[1, 1], + list(list("a", "b@c", "d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "@")))[1, 1], + list(list("a.b", "c.d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\\\")))[1, 1], + list(list("a.b@c.d 1", "b")) + ) + + l5 <- list(list(a = "abc")) + df5 <- createDataFrame(l5) + expect_equal( + collect(select(df5, repeat_string(df5$a, 1L)))[1, 1], + "abc" + ) + expect_equal( + collect(select(df5, repeat_string(df5$a, 3)))[1, 1], + "abcabcabc" + ) + expect_equal( + collect(select(df5, repeat_string(df5$a, -1)))[1, 1], + "" + ) }) test_that("date functions on a DataFrame", { From 5280d93e6ecec7327e7fcd3d8d1cb90e01e774fc Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 24 Apr 2017 18:18:59 -0700 Subject: [PATCH 340/512] [SPARK-20239][CORE] Improve HistoryServer's ACL mechanism ## What changes were proposed in this pull request? 
Current SHS (Spark History Server) two different ACLs: * ACL of base URL, it is controlled by "spark.acls.enabled" or "spark.ui.acls.enabled", and with this enabled, only user configured with "spark.admin.acls" (or group) or "spark.ui.view.acls" (or group), or the user who started SHS could list all the applications, otherwise none of them can be listed. This will also affect REST APIs which listing the summary of all apps and one app. * Per application ACL. This is controlled by "spark.history.ui.acls.enabled". With this enabled only history admin user and user/group who ran this app can access the details of this app. With this two ACLs, we may encounter several unexpected behaviors: 1. if base URL's ACL (`spark.acls.enable`) is enabled but user A has no view permission. User "A" cannot see the app list but could still access details of it's own app. 2. if ACLs of base URL (`spark.acls.enable`) is disabled, then user "A" could download any application's event log, even it is not run by user "A". 3. The changes of Live UI's ACL will affect History UI's ACL which share the same conf file. The unexpected behaviors is mainly because we have two different ACLs, ideally we should have only one to manage all. So to improve SHS's ACL mechanism, here in this PR proposed to: 1. Disable "spark.acls.enable" and only use "spark.history.ui.acls.enable" for history server. 2. Check permission for event-log download REST API. With this PR: 1. Admin user could see/download the list of all applications, as well as application details. 2. Normal user could see the list of all applications, but can only download and check the details of applications accessible to him. ## How was this patch tested? New UTs are added, also verified in real cluster. CC tgravescs vanzin please help to review, this PR changes the semantics you did previously. Thanks a lot. Author: jerryshao Closes #17582 from jerryshao/SPARK-20239. 
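As a rough illustration of the resulting access rule (the names and shapes below are invented for the sketch and are not Spark APIs): with history-server ACLs enabled, the admin ACL sees every application, other users can list everything but only open or download applications they ran themselves, and the live UI's base-URL ACL no longer plays a part.

```scala
final case class HistoryAcls(enabled: Boolean, admins: Set[String])

object HistoryAccessRule {
  // Listing the summary of all applications stays open to any authenticated user.
  def canListApplications(user: String, acls: HistoryAcls): Boolean = true

  // Application details and event-log downloads are restricted to the owner and admins.
  def canViewOrDownload(user: String, appOwner: String, acls: HistoryAcls): Boolean =
    !acls.enabled || user == appOwner || acls.admins.contains(user)
}

// e.g. acls = HistoryAcls(enabled = true, admins = Set("root"))
//   canViewOrDownload("root",  "jose", acls) == true   // admin
//   canViewOrDownload("jose",  "jose", acls) == true   // app owner
//   canViewOrDownload("alice", "jose", acls) == false  // other user
```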
--- .../history/ApplicationHistoryProvider.scala | 4 ++-- .../spark/deploy/history/HistoryServer.scala | 8 ++++++++ .../spark/status/api/v1/ApiRootResource.scala | 18 +++++++++++++++--- .../deploy/history/HistoryServerSuite.scala | 14 ++++++++------ 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index d7d82800b8b5..6d8758a3d3b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -86,7 +86,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return Count of application event logs that are currently under process */ def getEventLogsUnderProcess(): Int = { - return 0; + 0 } /** @@ -95,7 +95,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis */ def getLastUpdatedTime(): Long = { - return 0; + 0 } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 54f39f7620e5..d9c8fda99ef9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -301,6 +301,14 @@ object HistoryServer extends Logging { logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}") config.set(SecurityManager.SPARK_AUTH_CONF, "false") } + + if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) { + logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " + + "only using spark.history.ui.acl.enable") + config.set("spark.acls.enable", "false") + config.set("spark.ui.acls.enable", "false") + } + new SecurityManager(config) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 00f918c09c66..f17b63775482 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}/logs") def getEventLogs( @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) + try { + // withSparkUI will throw NotFoundException if attemptId exists for this application. + // So we need to try again with attempt id "1". 
+ withSparkUI(appId, None) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } catch { + case _: NotFoundException => + withSparkUI(appId, Some("1")) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + withSparkUI(appId, Some(attemptId)) { _ => + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } @Path("version") @@ -291,7 +304,6 @@ private[v1] trait ApiRequestContext { case None => throw new NotFoundException("no such app: " + appId) } } - } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 764156c3edc4..95acb9a54440 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -565,13 +565,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit(); - + logDir.deleteOnExit() } test("ui and api authorization checks") { - val appId = "app-20161115172038-0000" - val owner = "jose" + val appId = "local-1430917381535" + val owner = "irashid" val admin = "root" val other = "alice" @@ -590,8 +589,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val port = server.boundPort val testUrls = Seq( - s"http://localhost:$port/api/v1/applications/$appId/jobs", - s"http://localhost:$port/history/$appId/jobs/") + s"http://localhost:$port/api/v1/applications/$appId/1/jobs", + s"http://localhost:$port/history/$appId/1/jobs/", + s"http://localhost:$port/api/v1/applications/$appId/logs", + s"http://localhost:$port/api/v1/applications/$appId/1/logs", + s"http://localhost:$port/api/v1/applications/$appId/2/logs") tests.foreach { case (user, expectedCode) => testUrls.foreach { url => From f44c8a843ca512b319f099477415bc13eca2e373 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Apr 2017 21:48:04 -0700 Subject: [PATCH 341/512] [SPARK-20453] Bump master branch version to 2.3.0-SNAPSHOT This patch bumps the master branch version to `2.3.0-SNAPSHOT`. Author: Josh Rosen Closes #17753 from JoshRosen/SPARK-20453. 
--- assembly/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 5 +++++ repl/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 37 files changed, 42 insertions(+), 37 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 9d8607d9137c..742a4a1531e7 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8657af744c06..066970f24205 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 24c10fb1ddb9..2de882adcb58 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 5e5a80bd4446..a8488d8d1b70 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 1356c4723b66..6b81fc2b2b04 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 9345dc8f0cc4..f7e586ee777e 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index f03a4da5e715..680d0413b161 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 24ce36deeb16..7f245b5b6384 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/docs/_config.yml b/docs/_config.yml index 83bb30598d15..21255ef7a5c4 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. 
-SPARK_VERSION: 2.2.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.2.0 +SPARK_VERSION: 2.3.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.3.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.7" MESOS_VERSION: 1.0.0 diff --git a/examples/pom.xml b/examples/pom.xml index 91c2e81ebed2..e674e799f24a 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 8948df2da89e..0fa87a697454 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index f8ef8a991316..71016bc645ca 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 6d547c46d6a2..12630840e79d 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 46901d64eda9..87a09642405a 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 295142cbfdff..75df886ca44f 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 6cf448e65e8b..557d27296345 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 88499240cd56..6c98cb04fcfa 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 3fedd9eda195..f9c2dcb38dc0 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 8368a1f12218..849c8b465f99 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 90bb0e4987c8..48783d65826a 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index daa79e79163b..40a751a652fa 100644 --- a/external/kinesis-asl/pom.xml +++ 
b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 7da27817ebaf..36d555066b18 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8df33660ea9d..cb30e4a4af4b 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 025cd84f20f0..e9b46c4cf0ff 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 663f7fb0b010..043d13609fd2 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 82f840b0fc26..572670dc11b4 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index c1174593c192..a65692e0d131 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index feae76a087de..dbf933f28a78 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.3.x + lazy val v23excludes = v22excludes ++ Seq( + ) + // Exclude rules for 2.2.x lazy val v22excludes = v21excludes ++ Seq( // [SPARK-19652][UI] Do auth checks for REST API access. 
@@ -1003,6 +1007,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes case v if v.startsWith("2.0") => v20excludes diff --git a/repl/pom.xml b/repl/pom.xml index a256ae3b8418..6d133a3cfff7 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 03846d9f5a3b..20b53f2d8f98 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index a1b641c8eeb8..71d4ad681e16 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 765c92b8d3b9..8d80f8eca5db 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index b203f31a76f0..e170133f0f0b 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 9c879218ddc0..a5a8e2640586 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 0f249d7d5935..09dcc4055e00 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index de1be9c13e05..fea882ad1123 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 938ba2f6ac20..7ba4dc9842f1 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml From 31345fde82ada1f8bb12807b250b04726a1f6aa6 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 25 Apr 2017 13:05:20 +0800 Subject: [PATCH 342/512] [SPARK-20451] Filter out nested mapType datatypes from sort order in randomSplit ## What changes were proposed in this pull request? In `randomSplit`, It is possible that the underlying dataset doesn't guarantee the ordering of rows in its constituent partitions each time a split is materialized which could result in overlapping splits. To prevent this, as part of SPARK-12662, we explicitly sort each input partition to make the ordering deterministic. Given that `MapTypes` cannot be sorted this patch explicitly prunes them out from the sort order. Additionally, if the resulting sort order is empty, this patch then materializes the dataset to guarantee determinism. ## How was this patch tested? Extended `randomSplit on reordered partitions` in `DataFrameStatSuite` to also test for dataframes with mapTypes nested mapTypes. 
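For reference, a standalone sketch of the guarantee being exercised (it mirrors the new test rather than the patch itself, and assumes a local SparkSession): a fixed seed must yield non-overlapping splits that together cover the input, even when a MapType column cannot join the per-partition sort order.

```scala
import org.apache.spark.sql.SparkSession

object RandomSplitDeterminismSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // A column of ints plus a MapType column, spread over two partitions.
    val data = spark.sparkContext.parallelize(1 to 600, 2)
      .map(i => (i, Map(i -> i.toString)))
      .toDF("int", "map")

    val Array(first, second) = data.randomSplit(Array[Double](2, 3), seed = 1)
    val firstRows = first.collect().toSet
    val secondRows = second.collect().toSet

    assert(firstRows.intersect(secondRows).isEmpty)              // splits do not overlap
    assert(firstRows.union(secondRows) == data.collect().toSet)  // and cover the whole input
    spark.stop()
  }
}
```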
Author: Sameer Agarwal Closes #17751 from sameeragarwal/randomsplit2. --- .../scala/org/apache/spark/sql/Dataset.scala | 18 +++++--- .../apache/spark/sql/DataFrameStatSuite.scala | 43 ++++++++++++------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c6dcd93bbda6..06dd5500718d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1726,15 +1726,23 @@ class Dataset[T] private[sql]( // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the - // ordering deterministic. - // MapType cannot be sorted. - val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType]) - .map(SortOrder(_, Ascending)), global = false, logicalPlan) + // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out + // from the sort order. + val sortOrder = logicalPlan.output + .filter(attr => RowOrdering.isOrderable(attr.dataType)) + .map(SortOrder(_, Ascending)) + val plan = if (sortOrder.nonEmpty) { + Sort(sortOrder, global = false, logicalPlan) + } else { + // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism + cache() + logicalPlan + } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) }.toArray } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 97890a035a62..dd118f88e3bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("randomSplit on reordered partitions") { - // This test ensures that randomSplit does not create overlapping splits even when the - // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of - // rows in each partition. 
- val data = - sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") - val splits = data.randomSplit(Array[Double](2, 3), seed = 1) - assert(splits.length == 2, "wrong number of splits") + def testNonOverlappingSplits(data: DataFrame): Unit = { + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) - // Verify that the splits span the entire dataset - assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + // Verify that the splits don't overlap + assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty) - // Verify that the splits don't overlap - assert(splits(0).intersect(splits(1)).collect().isEmpty) + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } - // Verify that the results are deterministic across multiple runs - val firstRun = splits.toSeq.map(_.collect().toSeq) - val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) - assert(firstRun == secondRun) + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val dataWithInts = sparkContext.parallelize(1 to 600, 2) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int") + val dataWithMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Map(i -> i.toString))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map") + val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Array(Map(i -> i.toString)))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps") + + testNonOverlappingSplits(dataWithInts) + testNonOverlappingSplits(dataWithMaps) + testNonOverlappingSplits(dataWithArrayOfMaps) } test("pearson correlation") { From c8f1219510f469935aa9ff0b1c92cfe20372377c Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Tue, 25 Apr 2017 09:13:50 +0100 Subject: [PATCH 343/512] [SPARK-20455][DOCS] Fix Broken Docker IT Docs ## What changes were proposed in this pull request? Just added the Maven `test`goal. ## How was this patch tested? No test needed, just a trivial documentation fix. Author: Armin Braun Closes #17756 from original-brownbear/SPARK-20455. --- docs/building-spark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index e99b70f7a8b4..0f551bc66b8c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -232,7 +232,7 @@ Once installed, the `docker` service needs to be started, if not already running On Linux, this can be done by `sudo service docker start`. 
./build/mvn install -DskipTests - ./build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 or From 0bc7a90210aad9025c1e1bdc99f8e723c1bf0fbf Mon Sep 17 00:00:00 2001 From: Sergey Zhemzhitsky Date: Tue, 25 Apr 2017 09:18:36 +0100 Subject: [PATCH 344/512] [SPARK-20404][CORE] Using Option(name) instead of Some(name) Using Option(name) instead of Some(name) to prevent runtime failures when using accumulators created like the following ``` sparkContext.accumulator(0, null) ``` Author: Sergey Zhemzhitsky Closes #17740 from szhem/SPARK-20404-null-acc-names. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 99efc4893fda..0ec1bdd39b2f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1350,7 +1350,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Some(name)) + val acc = new Accumulable(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1414,7 +1414,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { - acc.register(this, name = Some(name)) + acc.register(this, name = Option(name)) } /** From 387565cf14b490810f9479ff3adbf776e2edecdc Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Tue, 25 Apr 2017 16:30:36 +0800 Subject: [PATCH 345/512] [SPARK-18901][FOLLOWUP][ML] Require in LR LogisticAggregator is redundant ## What changes were proposed in this pull request? This is a follow-up PR of #17478. ## How was this patch tested? Existing tests Author: wangmiao1981 Closes #17754 from wangmiao1981/followup. --- .../scala/org/apache/spark/ml/classification/LinearSVC.scala | 5 ++--- .../org/apache/spark/ml/regression/LinearRegression.scala | 5 ----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index f76b14eeeb54..7507c7539d4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -458,9 +458,7 @@ private class LinearSVCAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." 
+ - s" Expecting $numFeatures but got ${features.size}.") + if (weight == 0.0) return this val localFeaturesStd = bcFeaturesStd.value val localCoefficients = coefficientsArray @@ -512,6 +510,7 @@ private class LinearSVCAggregator( * @return This LinearSVCAggregator object. */ def merge(other: LinearSVCAggregator): this.type = { + if (other.weightSum != 0.0) { weightSum += other.weightSum lossSum += other.lossSum diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7e3c8fa5b6e..eaad54985229 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -971,9 +971,6 @@ private class LeastSquaresAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(dim == features.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1005,8 +1002,6 @@ private class LeastSquaresAggregator( * @return This LeastSquaresAggregator object. */ def merge(other: LeastSquaresAggregator): this.type = { - require(dim == other.dim, s"Dimensions mismatch when merging with another " + - s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") if (other.weightSum != 0) { totalCnt += other.totalCnt From 67eef47acfd26f1f0be3e8ef10453514f3655f62 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 25 Apr 2017 17:10:41 +0000 Subject: [PATCH 346/512] [SPARK-20449][ML] Upgrade breeze version to 0.13.1 ## What changes were proposed in this pull request? Upgrade breeze version to 0.13.1, which fixed some critical bugs of L-BFGS-B. ## How was this patch tested? Existing unit tests. Author: Yanbo Liang Closes #17746 from yanboliang/spark-20449. --- LICENSE | 1 + .../tests/testthat/test_mllib_classification.R | 10 +++++----- dev/deps/spark-deps-hadoop-2.6 | 12 +++++++----- dev/deps/spark-deps-hadoop-2.7 | 12 +++++++----- .../GeneralizedLinearRegression.scala | 4 ++-- .../spark/mllib/clustering/LDAModel.scala | 14 ++++---------- .../spark/mllib/optimization/LBFGSSuite.scala | 4 ++-- pom.xml | 2 +- python/pyspark/ml/classification.py | 18 ++++++++---------- 9 files changed, 37 insertions(+), 40 deletions(-) diff --git a/LICENSE b/LICENSE index 7950dd6ceb6d..c21032a1fd27 100644 --- a/LICENSE +++ b/LICENSE @@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt. 
(MIT License) RowsGroup (http://datatables.net/license/mit) (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) + (MIT License) machinist (https://github.com/typelevel/machinist) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 459254d271a5..af7cbdccf5d5 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -288,18 +288,18 @@ test_that("spark.mlp", { c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "0.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +310,8 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, - -10.2376130), tolerance = 1e-6) + expect_equal(head(summary$weights, 5), list(-0.5793153, -4.652961, 6.216155, -6.649478, + -10.51147), tolerance = 1e-3) }) test_that("spark.naiveBayes", { diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 73dc1f9a1398..9287bd47cf11 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -129,6 +129,8 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.0.0-shaded-protobuf.jar metrics-core-3.1.2.jar @@ -162,13 +164,13 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 
6bf0923a1d75..ab1de3d3dd8a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -130,6 +130,8 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.0.0-shaded-protobuf.jar metrics-core-3.1.2.jar @@ -163,13 +165,13 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index d6093a01c671..bff0d9bbb46f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -894,10 +894,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] object Probit extends Link("probit") { - override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) + override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).inverseCdf(mu) override def deriv(mu: Double): Double = { - 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu)) + 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).inverseCdf(mu)) } override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 7fd722a33292..15b723dadcff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -788,20 +788,14 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => - // TODO: Remove work-around for the breeze bug. 
- // https://github.com/scalanlp/breeze/issues/561 - val topIndices = if (k == topicCounts.length) { - Seq.range(0, k) - } else { - argtopk(topicCounts, k) - } + val topIndices = argtopk(topicCounts, k) val sumCounts = sum(topicCounts) val weights = if (sumCounts != 0) { - topicCounts(topIndices) / sumCounts + topicCounts(topIndices).toArray.map(_ / sumCounts) } else { - topicCounts(topIndices) + topicCounts(topIndices).toArray } - (docID.toLong, topIndices.toArray, weights.toArray) + (docID.toLong, topIndices.toArray, weights) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 572959200f47..3d6a9f8d84ca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -191,8 +191,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers // With smaller convergenceTol, it takes more steps. assert(lossLBFGS3.length > lossLBFGS2.length) - // Based on observation, lossLBFGS2 runs 5 iterations, no theoretically guaranteed. - assert(lossLBFGS3.length == 6) + // Based on observation, lossLBFGS3 runs 7 iterations, no theoretically guaranteed. + assert(lossLBFGS3.length == 7) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } diff --git a/pom.xml b/pom.xml index a65692e0d131..b6654c1411d2 100644 --- a/pom.xml +++ b/pom.xml @@ -658,7 +658,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.12 + 0.13.1 diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b4fc357e42d7..864968390ace 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -190,9 +190,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> blorModel = blor.fit(bdf) >>> blorModel.coefficients - DenseVector([5.5...]) + DenseVector([5.4...]) >>> blorModel.intercept - -2.68... + -2.63... >>> mdf = sc.parallelize([ ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), @@ -200,12 +200,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", ... family="multinomial") >>> mlorModel = mlor.fit(mdf) - >>> print(mlorModel.coefficientMatrix) - DenseMatrix([[-2.3...], - [ 0.2...], - [ 2.1... 
]]) + >>> mlorModel.coefficientMatrix + DenseMatrix(3, 1, [-2.3..., 0.2..., 2.1...], 1) >>> mlorModel.interceptVector - DenseVector([2.0..., 0.8..., -2.8...]) + DenseVector([2.1..., 0.6..., -2.8...]) >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> result = blorModel.transform(test0).head() >>> result.prediction @@ -213,7 +211,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> result.probability DenseVector([0.99..., 0.00...]) >>> result.rawPrediction - DenseVector([8.22..., -8.22...]) + DenseVector([8.12..., -8.12...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> blorModel.transform(test1).head().prediction 1.0 @@ -1490,9 +1488,9 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> ovr = OneVsRest(classifier=lr) >>> model = ovr.fit(df) >>> [x.coefficients for x in model.models] - [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])] + [DenseVector([4.9791, 2.426]), DenseVector([-4.1198, -5.9326]), DenseVector([-3.314, 5.2423])] >>> [x.intercept for x in model.models] - [-3.64747..., 2.55078..., -1.10165...] + [-5.06544..., 2.30341..., -1.29133...] >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF() >>> model.transform(test0).head().prediction 1.0 From 0a7f5f2798b6e8b2ba15e8b3aa07d5953ad1c695 Mon Sep 17 00:00:00 2001 From: ding Date: Tue, 25 Apr 2017 11:20:32 -0700 Subject: [PATCH 347/512] [SPARK-5484][GRAPHX] Periodically do checkpoint in Pregel ## What changes were proposed in this pull request? Pregel-based iterative algorithms with more than ~50 iterations begin to slow down and eventually fail with a StackOverflowError due to Spark's lack of support for long lineage chains. This PR causes Pregel to checkpoint the graph periodically if the checkpoint directory is set. This PR moves PeriodicGraphCheckpointer.scala from mllib to graphx, moves PeriodicRDDCheckpointer.scala, PeriodicCheckpointer.scala from mllib to core ## How was this patch tested? unit tests, manual tests Author: ding Author: dding3 Author: Michael Allman Closes #15125 from dding3/cp2_pregel.
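For illustration only — the following sketch is not part of this patch. It shows how a driver program would opt in to the periodic checkpointing added here: `spark.graphx.pregel.checkpointInterval` and the requirement to set a checkpoint directory come from this change, while the interval value, the checkpoint path, the input file, and the choice of connected components (any Pregel-based algorithm would do) are placeholder assumptions.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.GraphLoader

object PregelCheckpointSketch {
  def main(args: Array[String]): Unit = {
    // Checkpoint the Pregel graph and message RDDs every 10 supersteps (placeholder value)
    // so that lineage chains stay short even after hundreds of iterations.
    val conf = new SparkConf()
      .setAppName("PregelCheckpointSketch")
      .setMaster("local[*]")
      .set("spark.graphx.pregel.checkpointInterval", "10")
    val sc = new SparkContext(conf)

    // Checkpointing only takes effect when a checkpoint directory has been set.
    sc.setCheckpointDir("/tmp/pregel-checkpoints") // placeholder path

    // Connected components is implemented on top of Pregel, so it now checkpoints
    // periodically instead of accumulating an unbounded lineage chain.
    val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt") // placeholder input
    graph.connectedComponents().vertices.take(5).foreach(println)

    sc.stop()
  }
}
```

Leaving the interval at its default of -1 keeps the previous behaviour, i.e. no checkpointing at all.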
--- .../main/scala/org/apache/spark/rdd/RDD.scala | 4 +- .../rdd/util}/PeriodicRDDCheckpointer.scala | 3 +- .../spark/util}/PeriodicCheckpointer.scala | 14 ++- .../org/apache/spark/rdd/SortingSuite.scala | 2 +- .../util}/PeriodicRDDCheckpointerSuite.scala | 8 +- docs/configuration.md | 14 +++ docs/graphx-programming-guide.md | 9 +- .../org/apache/spark/graphx/Pregel.scala | 25 ++++- .../util}/PeriodicGraphCheckpointer.scala | 13 ++- .../PeriodicGraphCheckpointerSuite.scala | 105 +++++++++--------- .../org/apache/spark/ml/clustering/LDA.scala | 3 +- .../ml/tree/impl/GradientBoostedTrees.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 2 +- 13 files changed, 128 insertions(+), 76 deletions(-) rename {mllib/src/main/scala/org/apache/spark/mllib/impl => core/src/main/scala/org/apache/spark/rdd/util}/PeriodicRDDCheckpointer.scala (97%) rename {mllib/src/main/scala/org/apache/spark/mllib/impl => core/src/main/scala/org/apache/spark/util}/PeriodicCheckpointer.scala (95%) rename {mllib/src/test/scala/org/apache/spark/mllib/impl => core/src/test/scala/org/apache/spark/util}/PeriodicRDDCheckpointerSuite.scala (96%) rename {mllib/src/main/scala/org/apache/spark/mllib/impl => graphx/src/main/scala/org/apache/spark/graphx/util}/PeriodicGraphCheckpointer.scala (91%) rename {mllib/src/test/scala/org/apache/spark/mllib/impl => graphx/src/test/scala/org/apache/spark/graphx/util}/PeriodicGraphCheckpointerSuite.scala (70%) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e524675332d1..63a87e7f09d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag]( val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + queue ++= collectionUtils.takeOrdered(items, num)(ord) Iterator.single(queue) } if (mapRDDs.partitions.length == 0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index 145dc22b7428..ab72addb2466 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.mllib.impl +package org.apache.spark.rdd.util import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala index 4dd498cd91b4..ce06e18879a4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.util import scala.collection.mutable @@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ -private[mllib] abstract class PeriodicCheckpointer[T]( +private[spark] abstract class PeriodicCheckpointer[T]( val checkpointInterval: Int, val sc: SparkContext) extends Logging { @@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T]( /** Get list of checkpoint files for this given Dataset */ protected def getCheckpointFiles(data: T): Iterable[String] + /** + * Call this to unpersist the Dataset. + */ + def unpersistDataSet(): Unit = { + while (persistedQueue.nonEmpty) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + } + /** * Call this at the end to delete any remaining checkpoint files. */ diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f151823a..7f20206202cb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala similarity index 96% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala rename to core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala index 14adf8c29fc6..f9e1b791c86e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala @@ -15,18 +15,18 @@ * limitations under the License. 
*/ -package org.apache.spark.mllib.impl +package org.apache.spark.utils import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext { import PeriodicRDDCheckpointerSuite._ diff --git a/docs/configuration.md b/docs/configuration.md index 6b65d2bcb83e..87b76322cae5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2149,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE) +### GraphX + + + + + + + + +
    Property NameDefaultMeaning
    spark.graphx.pregel.checkpointInterval-1 + Checkpoint interval for graph and message in Pregel. It used to avoid stackOverflowError due to long lineage chains + after lots of iterations. The checkpoint is disabled by default. +
    + ### Deploy diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e271b28fb4f2..76aa7b405e18 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,9 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note calls to graph.cache have been removed): +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, +say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): {% highlight scala %} class GraphOps[VD, ED] { @@ -722,6 +724,7 @@ class GraphOps[VD, ED] { : Graph[VD, ED] = { // Receive the initial message at each vertex var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache() + // compute the messages var messages = g.mapReduceTriplets(sendMsg, mergeMsg) var activeMessages = messages.count() @@ -734,8 +737,8 @@ class GraphOps[VD, ED] { // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. - messages = g.mapReduceTriplets( - sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + messages = GraphXUtils.mapReduceTriplets( + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() activeMessages = messages.count() i += 1 } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 646462b4a835..755c6febc48e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -19,7 +19,10 @@ package org.apache.spark.graphx import scala.reflect.ClassTag +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer /** * Implements a Pregel-like bulk-synchronous message-passing API. @@ -122,27 +125,39 @@ object Pregel extends Logging { require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") - var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + val checkpointInterval = graph.vertices.sparkContext.getConf + .getInt("spark.graphx.pregel.checkpointInterval", -1) + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)) + val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED]( + checkpointInterval, graph.vertices.sparkContext) + graphCheckpointer.update(g) + // compute the messages var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg) + val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)]( + checkpointInterval, graph.vertices.sparkContext) + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) var activeMessages = messages.count() + // Loop var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { // Receive the messages and update the vertices. 
prevG = g - g = g.joinVertices(messages)(vprog).cache() + g = g.joinVertices(messages)(vprog) + graphCheckpointer.update(g) val oldMessages = messages // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. messages = GraphXUtils.mapReduceTriplets( - g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))) // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages // and the vertices of g). + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) activeMessages = messages.count() logInfo("Pregel finished iteration " + i) @@ -154,7 +169,9 @@ object Pregel extends Logging { // count the iteration i += 1 } - messages.unpersist(blocking = false) + messageCheckpointer.unpersistDataSet() + graphCheckpointer.deleteAllCheckpoints() + messageCheckpointer.deleteAllCheckpoints() g } // end of apply diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala similarity index 91% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala rename to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala index 80074897567e..fda501aa757d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** @@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Move this out of MLlib? */ -private[mllib] class PeriodicGraphCheckpointer[VD, ED]( +private[spark] class PeriodicGraphCheckpointer[VD, ED]( checkpointInterval: Int, sc: SparkContext) extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { @@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.vertices.persist() + /* We need to use cache because persist does not honor the default storage level requested + * when constructing the graph. Only cache does that. 
+ */ + data.vertices.cache() } if (data.edges.getStorageLevel == StorageLevel.NONE) { - data.edges.persist() + data.edges.cache() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala similarity index 70% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala rename to graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala index a13e7f63a929..e0c65e6940f6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala @@ -15,77 +15,81 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.graphx.{Edge, Graph} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext { import PeriodicGraphCheckpointerSuite._ test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] - val graph1 = createGraph(sc) - val checkpointer = - new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkPersistence(graphsToCheck, 1) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkPersistence(graphsToCheck, iteration) - iteration += 1 + withSpark { sc => + val graph1 = createGraph(sc) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } } } test("Checkpointing") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - val checkpointInterval = 2 - var graphsToCheck = Seq.empty[GraphToCheck] - sc.setCheckpointDir(path) - val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( - checkpointInterval, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graph1.edges.count() - graph1.vertices.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkCheckpoint(graphsToCheck, 1, checkpointInterval) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graph.vertices.count() - graph.edges.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkCheckpoint(graphsToCheck, iteration, checkpointInterval) - iteration += 1 - } + withSpark { sc => + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + 
sc.setCheckpointDir(path) + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } - checkpointer.deleteAllCheckpoints() - graphsToCheck.foreach { graph => - confirmCheckpointRemoved(graph.graph) - } + checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } - Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(tempDir) + } } } private object PeriodicGraphCheckpointerSuite { + private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) @@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite { Edge[Double](3, 4, 0)) def createGraph(sc: SparkContext): Graph[Double, Double] = { - Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + Graph.fromEdges[Double, Double]( + sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel) } def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { @@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite { assert(graph.vertices.getStorageLevel == StorageLevel.NONE) assert(graph.edges.getStorageLevel == StorageLevel.NONE) } else { - assert(graph.vertices.getStorageLevel != StorageLevel.NONE) - assert(graph.edges.getStorageLevel != StorageLevel.NONE) + assert(graph.vertices.getStorageLevel == defaultStorageLevel) + assert(graph.edges.getStorageLevel == defaultStorageLevel) } } catch { case _: AssertionError => diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 2f50dc7c85f3..e3026c8efa82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -36,7 +36,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -45,9 +44,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils - private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter with HasSeed with HasCheckpointInterval { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 
4c525c0714ec..ce2bd7b430f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} -import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 48bae4276c48..3697a9b46dd8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel From caf392025ce21d701b503112060fa016d5eabe04 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 25 Apr 2017 17:05:20 -0700 Subject: [PATCH 348/512] [SPARK-18127] Add hooks and extension points to Spark ## What changes were proposed in this pull request? This patch adds support for customizing the spark session by injecting user-defined custom extensions. This allows a user to add custom analyzer rules/checks, optimizer rules, planning strategies or even a customized parser. ## How was this patch tested? Unit Tests in SparkSessionExtensionSuite Author: Sameer Agarwal Closes #17724 from sameeragarwal/session-extensions. 
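For illustration only — the following sketch is not part of this patch. It exercises the new builder-side entry point: `withExtensions` and `injectOptimizerRule` are introduced by this change, while the no-op rule, the master/app-name settings, and the sample query are placeholder assumptions.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// A do-nothing optimizer rule (hypothetical), used only to show the injection point.
case class NoOpOptimizerRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

object ExtensionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("ExtensionsSketch")
      .withExtensions { extensions =>
        // Rule builders receive the session when the rule is built; they should not
        // assume the session is fully initialized at injection time.
        extensions.injectOptimizerRule(NoOpOptimizerRule)
      }
      .getOrCreate()

    // The injected rule now participates in the operator optimization batch.
    spark.range(10).selectExpr("sum(id)").show()
    spark.stop()
  }
}
```

The same configurator can instead be packaged as a class implementing `SparkSessionExtensions => Unit` and referenced through the new `spark.sql.extensions` static configuration, which is the path covered by the "use custom class for extensions" test below.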
--- .../sql/catalyst/parser/ParseDriver.scala | 9 +- .../sql/catalyst/parser/ParserInterface.scala | 35 +++- .../spark/sql/internal/StaticSQLConf.scala | 6 + .../org/apache/spark/sql/SparkSession.scala | 45 ++++- .../spark/sql/SparkSessionExtensions.scala | 171 ++++++++++++++++++ .../internal/BaseSessionStateBuilder.scala | 33 +++- .../sql/SparkSessionExtensionSuite.scala | 144 +++++++++++++++ 7 files changed, 418 insertions(+), 25 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 80ab75cc17fa..dcccbd0ed8d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class AbstractSqlParser extends ParserInterface with Logging { /** Creates/Resolves DataType for a given SQL string. */ - def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => - // TODO add this to the parser interface. + override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => astBuilder.visitSingleDataType(parser.singleDataType()) } @@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { } /** Creates FunctionIdentifier for a given SQL string. */ - def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser => - astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + parse(sqlText) { parser => + astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index db3598bde04d..75240d219622 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -17,30 +17,51 @@ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** * Interface for a parser. */ +@DeveloperApi trait ParserInterface { - /** Creates LogicalPlan for a given SQL string. */ + /** + * Parse a string to a [[LogicalPlan]]. + */ + @throws[ParseException]("Text cannot be parsed to a LogicalPlan") def parsePlan(sqlText: String): LogicalPlan - /** Creates Expression for a given SQL string. */ + /** + * Parse a string to an [[Expression]]. + */ + @throws[ParseException]("Text cannot be parsed to an Expression") def parseExpression(sqlText: String): Expression - /** Creates TableIdentifier for a given SQL string. */ + /** + * Parse a string to a [[TableIdentifier]]. 
+ */ + @throws[ParseException]("Text cannot be parsed to a TableIdentifier") def parseTableIdentifier(sqlText: String): TableIdentifier - /** Creates FunctionIdentifier for a given SQL string. */ + /** + * Parse a string to a [[FunctionIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier") def parseFunctionIdentifier(sqlText: String): FunctionIdentifier /** - * Creates StructType for a given SQL string, which is a comma separated list of field - * definitions which will preserve the correct Hive metadata. + * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list + * of field definitions which will preserve the correct Hive metadata. */ + @throws[ParseException]("Text cannot be parsed to a schema") def parseTableSchema(sqlText: String): StructType + + /** + * Parse a string to a [[DataType]]. + */ + @throws[ParseException]("Text cannot be parsed to a DataType") + def parseDataType(sqlText: String): DataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index af1a9cee2962..c6c0a605d89f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -81,4 +81,10 @@ object StaticSQLConf { "SQL configuration and the current database.") .booleanConf .createWithDefault(false) + + val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions") + .doc("Name of the class used to configure Spark Session extensions. The class should " + + "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") + .stringConf + .createOptional } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 95f3463dfe62..a519492ed8f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -77,11 +77,12 @@ import org.apache.spark.util.Utils class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState], - @transient private val parentSessionState: Option[SessionState]) + @transient private val parentSessionState: Option[SessionState], + @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { - this(sc, None, None) + this(sc, None, None, new SparkSessionExtensions) } sparkContext.assertNotStopped() @@ -219,7 +220,7 @@ class SparkSession private( * @since 2.0.0 */ def newSession(): SparkSession = { - new SparkSession(sparkContext, Some(sharedState), parentSessionState = None) + new SparkSession(sparkContext, Some(sharedState), parentSessionState = 
None, extensions) } /** @@ -235,7 +236,7 @@ class SparkSession private( * implementation is Hive, this will initialize the metastore, which may take some time. */ private[sql] def cloneSession(): SparkSession = { - val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState)) + val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions) result.sessionState // force copy of SessionState result } @@ -754,6 +755,8 @@ object SparkSession { private[this] val options = new scala.collection.mutable.HashMap[String, String] + private[this] val extensions = new SparkSessionExtensions + private[this] var userSuppliedContext: Option[SparkContext] = None private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { @@ -847,6 +850,17 @@ object SparkSession { } } + /** + * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules, + * Optimizer rules, Planning Strategies or a customized parser. + * + * @since 2.2.0 + */ + def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + f(extensions) + this + } + /** * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new * one based on the options set in this builder. @@ -903,7 +917,26 @@ object SparkSession { } sc } - session = new SparkSession(sparkContext) + + // Initialize extensions if the user has defined a configurator class. + val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) + if (extensionConfOption.isDefined) { + val extensionConfClassName = extensionConfOption.get + try { + val extensionConfClass = Utils.classForName(extensionConfClassName) + val extensionConf = extensionConfClass.newInstance() + .asInstanceOf[SparkSessionExtensions => Unit] + extensionConf(extensions) + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | + _: ClassNotFoundException | + _: NoClassDefFoundError) => + logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e) + } + } + + session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala new file mode 100644 index 000000000000..f99c108161f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.collection.mutable + +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * :: Experimental :: + * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability + * regarding binary compatibility and source compatibility of methods here. + * + * This current provides the following extension points: + * - Analyzer Rules. + * - Check Analysis Rules + * - Optimizer Rules. + * - Planning Strategies. + * - Customized Parser. + * - (External) Catalog listeners. + * + * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for + * example: + * {{{ + * SparkSession.builder() + * .master("...") + * .conf("...", true) + * .withExtensions { extensions => + * extensions.injectResolutionRule { session => + * ... + * } + * extensions.injectParser { (session, parser) => + * ... + * } + * } + * .getOrCreate() + * }}} + * + * Note that none of the injected builders should assume that the [[SparkSession]] is fully + * initialized and should not touch the session's internals (e.g. the SessionState). + */ +@DeveloperApi +@Experimental +@InterfaceStability.Unstable +class SparkSessionExtensions { + type RuleBuilder = SparkSession => Rule[LogicalPlan] + type CheckRuleBuilder = SparkSession => LogicalPlan => Unit + type StrategyBuilder = SparkSession => Strategy + type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + + private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + resolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed as part of the resolution phase of analysis. + */ + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + postHocResolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed after resolution. + */ + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder] + + /** + * Build the check analysis `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = { + checkRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an check analysis `Rule` builder into the [[SparkSession]]. The injected rules will + * be executed after the analysis phase. A check analysis rule is used to detect problems with a + * LogicalPlan and should throw an exception when a problem is found. 
+ */ + def injectCheckRule(builder: CheckRuleBuilder): Unit = { + checkRuleBuilders += builder + } + + private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder] + + private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + optimizerRules.map(_.apply(session)) + } + + /** + * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be + * executed during the operator optimization batch. An optimizer rule is used to improve the + * quality of an analyzed logical plan; these rules should never modify the result of the + * LogicalPlan. + */ + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + + private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { + plannerStrategyBuilders.map(_.apply(session)) + } + + /** + * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will + * be used to convert a `LogicalPlan` into a executable + * [[org.apache.spark.sql.execution.SparkPlan]]. + */ + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder] + + private[sql] def buildParser( + session: SparkSession, + initial: ParserInterface): ParserInterface = { + parserBuilders.foldLeft(initial) { (parser, builder) => + builder(session, parser) + } + } + + /** + * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session + * and an initial parser. The latter allows for a user to create a partial parser and to delegate + * to the underlying parser for completeness. If a user injects more parsers, then the parsers + * are stacked on top of each other. + */ + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index df7c3678b780..2a801d87b12e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveTimeZone} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface @@ -63,6 +63,11 @@ abstract class BaseSessionStateBuilder( */ protected def newBuilder: NewBuilder + /** + * Session extensions defined in the [[SparkSession]]. + */ + protected def extensions: SparkSessionExtensions = session.extensions + /** * Extract entries from `SparkConf` and put them in the `SQLConf` */ @@ -108,7 +113,9 @@ abstract class BaseSessionStateBuilder( * * Note: this depends on the `conf` field. 
*/ - protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + protected lazy val sqlParser: ParserInterface = { + extensions.buildParser(session, new SparkSqlParser(conf)) + } /** * ResourceLoader that is used to load function resources and jars. @@ -171,7 +178,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildResolutionRules(session) + } /** * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of @@ -179,7 +188,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildPostHocResolutionRules(session) + } /** * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating @@ -187,7 +198,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil + protected def customCheckRules: Seq[LogicalPlan => Unit] = { + extensions.buildCheckRules(session) + } /** * Logical query plan optimizer. @@ -207,7 +220,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `optimizer` function. */ - protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = { + extensions.buildOptimizerRules(session) + } /** * Planner that converts optimized logical plans to physical plans. @@ -227,7 +242,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `planner` function. */ - protected def customPlanningStrategies: Seq[Strategy] = Nil + protected def customPlanningStrategies: Seq[Strategy] = { + extensions.buildPlannerStrategies(session) + } /** * Create a query execution object. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala new file mode 100644 index 000000000000..43db79663322 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Test cases for the [[SparkSessionExtensions]]. + */ +class SparkSessionExtensionSuite extends SparkFunSuite { + type ExtensionsBuilder = SparkSessionExtensions => Unit + private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder + + private def stop(spark: SparkSession): Unit = { + spark.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = { + val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate() + try f(spark) finally { + stop(spark) + } + } + + test("inject analyzer rule") { + withSession(_.injectResolutionRule(MyRule)) { session => + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } + } + + test("inject check analysis rule") { + withSession(_.injectCheckRule(MyCheckRule)) { session => + assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) + } + } + + test("inject optimizer rule") { + withSession(_.injectOptimizerRule(MyRule)) { session => + assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) + } + } + + test("inject spark planner strategy") { + withSession(_.injectPlannerStrategy(MySparkStrategy)) { session => + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + } + } + + test("inject parser") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + } + withSession(extension) { session => + assert(session.sessionState.sqlParser == CatalystSqlParser) + } + } + + test("inject stacked parsers") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + extensions.injectParser(MyParser) + extensions.injectParser(MyParser) + } + withSession(extension) { session => + val parser = MyParser(session, MyParser(session, CatalystSqlParser)) + assert(session.sessionState.sqlParser == parser) + } + } + + test("use custom class for extensions") { + val session = SparkSession.builder() + .master("local[1]") + .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName) + .getOrCreate() + try { + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } finally { + stop(session) + } + } +} + +case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan +} + +case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) { + override def apply(plan: LogicalPlan): Unit = { } +} + +case class MySparkStrategy(spark: SparkSession) extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty +} + +case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface { + override def parsePlan(sqlText: String): LogicalPlan 
= + delegate.parsePlan(sqlText) + + override def parseExpression(sqlText: String): Expression = + delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = + delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = + delegate.parseDataType(sqlText) +} + +class MyExtensions extends (SparkSessionExtensions => Unit) { + def apply(e: SparkSessionExtensions): Unit = { + e.injectPlannerStrategy(MySparkStrategy) + e.injectResolutionRule(MyRule) + } +} From 57e1da39464131329318b723caa54df9f55fa54f Mon Sep 17 00:00:00 2001 From: Eric Wasserman Date: Wed, 26 Apr 2017 11:42:43 +0800 Subject: [PATCH 349/512] [SPARK-16548][SQL] Inconsistent error handling in JSON parsing SQL functions ## What changes were proposed in this pull request? change to using Jackson's `com.fasterxml.jackson.core.JsonFactory` public JsonParser createParser(String content) ## How was this patch tested? existing unit tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Wasserman Closes #17693 from ewasserman/SPARK-20314. --- .../catalyst/expressions/jsonExpressions.scala | 12 +++++++++--- .../expressions/JsonExpressionsSuite.scala | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index df4d406b84d6..9fb0ea68153d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} import scala.util.parsing.combinator.RegexParsers @@ -149,7 +149,10 @@ case class GetJsonObject(json: Expression, path: Expression) if (parsed.isDefined) { try { - Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser => + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader( + new ByteArrayInputStream(jsonStr.getBytes), "UTF-8"))) { parser => val output = new ByteArrayOutputStream() val matched = Utils.tryWithResource( jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => @@ -393,7 +396,10 @@ case class JsonTuple(children: Seq[Expression]) } try { - Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) { + /* We know the bytes are UTF-8 encoded. 
Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader( + new ByteArrayInputStream(json.getBytes), "UTF-8"))) { parser => parseRow(parser, input) } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index c5b72235e5db..4402ad4e9a9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -39,6 +39,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { |"fb:testid":"1234"} |""".stripMargin + /* invalid json with leading nulls would trigger java.io.CharConversionException + in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */ + val badJson = "\0\0\0A\1AAA" + test("$.store.bicycle") { checkEvaluation( GetJsonObject(Literal(json), Literal("$.store.bicycle")), @@ -224,6 +228,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null) } + test("SPARK-16548: character conversion") { + checkEvaluation( + GetJsonObject(Literal(badJson), Literal("$.a")), + null + ) + } + test("non foldable literal") { checkEvaluation( GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")), @@ -340,6 +351,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(null, null, null, null, null)) } + test("SPARK-16548: json_tuple - invalid json with leading nulls") { + checkJsonTuple( + JsonTuple(Literal(badJson) :: jsonTupleQuery), + InternalRow(null, null, null, null, null)) + } + test("json_tuple - preserve newlines") { checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), From df58a95a33b739462dbe84e098839af2a8643d45 Mon Sep 17 00:00:00 2001 From: zero323 Date: Tue, 25 Apr 2017 22:00:45 -0700 Subject: [PATCH 350/512] [SPARK-20437][R] R wrappers for rollup and cube ## What changes were proposed in this pull request? - Add `rollup` and `cube` methods and corresponding generics. - Add short description to the vignette. ## How was this patch tested? - Existing unit tests. - Additional unit tests covering new features. - `check-cran.sh`. Author: zero323 Closes #17728 from zero323/SPARK-20437. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 73 +++++++++++++++- R/pkg/R/generics.R | 8 ++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 102 ++++++++++++++++++++++ R/pkg/vignettes/sparkr-vignettes.Rmd | 15 ++++ docs/sparkr.md | 30 +++++++ 6 files changed, 229 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 95d5cc6d1c78..280046165848 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -101,6 +101,7 @@ exportMethods("arrange", "createOrReplaceTempView", "crossJoin", "crosstab", + "cube", "dapply", "dapplyCollect", "describe", @@ -143,6 +144,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "rollup", "sample", "sample_frac", "sampleBy", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 88a138fd8eb1..cd6f03a13d7c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1321,7 +1321,7 @@ setMethod("toRDD", #' Groups the SparkDataFrame using the specified columns, so we can run aggregation on them. #' #' @param x a SparkDataFrame. -#' @param ... 
variable(s) (character names(s) or Column(s)) to group on. +#' @param ... character name(s) or Column(s) to group on. #' @return A GroupedData. #' @family SparkDataFrame functions #' @aliases groupBy,SparkDataFrame-method @@ -1337,6 +1337,7 @@ setMethod("toRDD", #' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") #' } #' @note groupBy since 1.4.0 +#' @seealso \link{agg}, \link{cube}, \link{rollup} setMethod("groupBy", signature(x = "SparkDataFrame"), function(x, ...) { @@ -3642,3 +3643,73 @@ setMethod("checkpoint", df <- callJMethod(x@sdf, "checkpoint", as.logical(eager)) dataFrame(df) }) + +#' cube +#' +#' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. +#' +#' If grouping expression is missing \code{cube} creates a single global aggregate and is equivalent to +#' direct application of \link{agg}. +#' +#' @param x a SparkDataFrame. +#' @param ... character name(s) or Column(s) to group on. +#' @return A GroupedData. +#' @family SparkDataFrame functions +#' @aliases cube,SparkDataFrame-method +#' @rdname cube +#' @name cube +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' mean(cube(df, "cyl", "gear", "am"), "mpg") +#' +#' # Following calls are equivalent +#' agg(cube(carsDF), mean(carsDF$mpg)) +#' agg(carsDF, mean(carsDF$mpg)) +#' } +#' @note cube since 2.3.0 +#' @seealso \link{agg}, \link{groupBy}, \link{rollup} +setMethod("cube", + signature(x = "SparkDataFrame"), + function(x, ...) { + cols <- list(...) + jcol <- lapply(cols, function(x) if (class(x) == "Column") x@jc else column(x)@jc) + sgd <- callJMethod(x@sdf, "cube", jcol) + groupedData(sgd) + }) + +#' rollup +#' +#' Create a multi-dimensional rollup for the SparkDataFrame using the specified columns. +#' +#' If grouping expression is missing \code{rollup} creates a single global aggregate and is equivalent to +#' direct application of \link{agg}. +#' +#' @param x a SparkDataFrame. +#' @param ... character name(s) or Column(s) to group on. +#' @return A GroupedData. +#' @family SparkDataFrame functions +#' @aliases rollup,SparkDataFrame-method +#' @rdname rollup +#' @name rollup +#' @export +#' @examples +#'\dontrun{ +#' df <- createDataFrame(mtcars) +#' mean(rollup(df, "cyl", "gear", "am"), "mpg") +#' +#' # Following calls are equivalent +#' agg(rollup(carsDF), mean(carsDF$mpg)) +#' agg(carsDF, mean(carsDF$mpg)) +#' } +#' @note rollup since 2.3.0 +#' @seealso \link{agg}, \link{cube}, \link{groupBy} +setMethod("rollup", + signature(x = "SparkDataFrame"), + function(x, ...) { + cols <- list(...) + jcol <- lapply(cols, function(x) if (class(x) == "Column") x@jc else column(x)@jc) + sgd <- callJMethod(x@sdf, "rollup", jcol) + groupedData(sgd) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5e7a1c60c2b3..749ee9b54cc8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -483,6 +483,10 @@ setGeneric("createOrReplaceTempView", # @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) +#' @rdname cube +#' @export +setGeneric("cube", function(x, ...) { standardGeneric("cube") }) + #' @rdname dapply #' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) @@ -631,6 +635,10 @@ setGeneric("sample", standardGeneric("sample") }) +#' @rdname rollup +#' @export +setGeneric("rollup", function(x, ...) 
{ standardGeneric("rollup") }) + #' @rdname sample #' @export setGeneric("sample_frac", diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index c21ba2f1a138..2cef7191d4f2 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1816,6 +1816,108 @@ test_that("pivot GroupedData column", { expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings"))) }) +test_that("test multi-dimensional aggregations with cube and rollup", { + df <- createDataFrame(data.frame( + id = 1:6, + year = c(2016, 2016, 2016, 2017, 2017, 2017), + salary = c(10000, 15000, 20000, 22000, 32000, 21000), + department = c("management", "rnd", "sales", "management", "rnd", "sales") + )) + + actual_cube <- collect( + orderBy( + agg( + cube(df, "year", "department"), + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + ), + "year", "department" + ) + ) + + expected_cube <- data.frame( + year = c(rep(NA, 4), rep(2016, 4), rep(2017, 4)), + department = rep(c(NA, "management", "rnd", "sales"), times = 3), + total_salary = c( + 120000, # Total + 10000 + 22000, 15000 + 32000, 20000 + 21000, # Department only + 20000 + 15000 + 10000, # 2016 + 10000, 15000, 20000, # 2016 each department + 21000 + 32000 + 22000, # 2017 + 22000, 32000, 21000 # 2017 each department + ), + average_salary = c( + # Total + mean(c(20000, 15000, 10000, 21000, 32000, 22000)), + # Mean by department + mean(c(10000, 22000)), mean(c(15000, 32000)), mean(c(20000, 21000)), + mean(c(10000, 15000, 20000)), # 2016 + 10000, 15000, 20000, # 2016 each department + mean(c(21000, 32000, 22000)), # 2017 + 22000, 32000, 21000 # 2017 each department + ), + stringsAsFactors = FALSE + ) + + expect_equal(actual_cube, expected_cube) + + # cube should accept column objects + expect_equal( + count(sum(cube(df, df$year, df$department), "salary")), + 12 + ) + + # cube without columns should result in a single aggregate + expect_equal( + collect(agg(cube(df), expr("sum(salary) as total_salary"))), + data.frame(total_salary = 120000) + ) + + actual_rollup <- collect( + orderBy( + agg( + rollup(df, "year", "department"), + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + ), + "year", "department" + ) + ) + + expected_rollup <- data.frame( + year = c(NA, rep(2016, 4), rep(2017, 4)), + department = c(NA, rep(c(NA, "management", "rnd", "sales"), times = 2)), + total_salary = c( + 120000, # Total + 20000 + 15000 + 10000, # 2016 + 10000, 15000, 20000, # 2016 each department + 21000 + 32000 + 22000, # 2017 + 22000, 32000, 21000 # 2017 each department + ), + average_salary = c( + # Total + mean(c(20000, 15000, 10000, 21000, 32000, 22000)), + mean(c(10000, 15000, 20000)), # 2016 + 10000, 15000, 20000, # 2016 each department + mean(c(21000, 32000, 22000)), # 2017 + 22000, 32000, 21000 # 2017 each department + ), + stringsAsFactors = FALSE + ) + + expect_equal(actual_rollup, expected_rollup) + + # cube should accept column objects + expect_equal( + count(sum(rollup(df, df$year, df$department), "salary")), + 9 + ) + + # rollup without columns should result in a single aggregate + expect_equal( + collect(agg(rollup(df), expr("sum(salary) as total_salary"))), + data.frame(total_salary = 120000) + ) +}) + test_that("arrange() and orderBy() on a DataFrame", { df <- read.json(jsonPath) sorted <- arrange(df, df$age) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 
f81dbab10b1e..4b9d6c380609 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -308,6 +308,21 @@ numCyl <- summarize(groupBy(carsDF, carsDF$cyl), count = n(carsDF$cyl)) head(numCyl) ``` +Use `cube` or `rollup` to compute subtotals across multiple dimensions. + +```{r} +mean(cube(carsDF, "cyl", "gear", "am"), "mpg") +``` + +generates groupings for all possible combinations of grouping columns, while + +```{r} +mean(rollup(carsDF, "cyl", "gear", "am"), "mpg") +``` + +generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}. + + #### Operating on Columns SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. diff --git a/docs/sparkr.md b/docs/sparkr.md index a1a35a7757e5..e015ab260fca 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -264,6 +264,36 @@ head(arrange(waiting_counts, desc(waiting_counts$count))) {% endhighlight %} +In addition to standard aggregations, SparkR supports [OLAP cube](https://en.wikipedia.org/wiki/OLAP_cube) operators `cube`: + +
    +{% highlight r %} +head(agg(cube(df, "cyl", "disp", "gear"), avg(df$mpg))) +## cyl disp gear avg(mpg) +##1 NA 140.8 4 22.8 +##2 4 75.7 4 30.4 +##3 8 400.0 3 19.2 +##4 8 318.0 3 15.5 +##5 NA 351.0 NA 15.8 +##6 NA 275.8 NA 16.3 +{% endhighlight %} +
    + +and `rollup`: + +
    +{% highlight r %} +head(agg(rollup(df, "cyl", "disp", "gear"), avg(df$mpg))) +## cyl disp gear avg(mpg) +##1 4 75.7 4 30.4 +##2 8 400.0 3 19.2 +##3 8 318.0 3 15.5 +##4 4 78.7 NA 32.4 +##5 8 304.0 3 15.2 +##6 4 79.0 NA 27.3 +{% endhighlight %} +
    + ### Operating on Columns SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. From 7a365257e934e838bd90f6a0c50362bf47202b0e Mon Sep 17 00:00:00 2001 From: anabranch Date: Wed, 26 Apr 2017 09:49:05 +0100 Subject: [PATCH 351/512] [SPARK-20400][DOCS] Remove References to 3rd Party Vendor Tools ## What changes were proposed in this pull request? Simple documentation change to remove explicit vendor references. ## How was this patch tested? NA Please review http://spark.apache.org/contributing.html before opening a pull request. Author: anabranch Closes #17695 from anabranch/remove-vendor. --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 87b76322cae5..8b53e92ccd41 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2270,8 +2270,8 @@ should be included on Spark's classpath: * `hdfs-site.xml`, which provides default behaviors for the HDFS client. * `core-site.xml`, which sets the default filesystem name. -The location of these configuration files varies across CDH and HDP versions, but -a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create +The location of these configuration files varies across Hadoop versions, but +a common location is inside of `/etc/hadoop/conf`. Some tools create configurations on-the-fly, but offer a mechanisms to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` From 7fecf5130163df9c204a2764d121a7011d007f4e Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Wed, 26 Apr 2017 08:23:31 -0500 Subject: [PATCH 352/512] =?UTF-8?q?[SPARK-19812]=20YARN=20shuffle=20servic?= =?UTF-8?q?e=20fails=20to=20relocate=20recovery=20DB=20acro=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ss NFS directories ## What changes were proposed in this pull request? Change from using java Files.move to use Hadoop filesystem operations to move the directories. The java Files.move does not work when moving directories across NFS mounts and in fact also says that if the directory has entries you should do a recursive move. We are already using Hadoop filesystem here so just use the local filesystem from there as it handles this properly. Note that the DB here is actually a directory of files and not just a single file, hence the change in the name of the local var. ## How was this patch tested? Ran YarnShuffleServiceSuite unit tests. Unfortunately couldn't easily add one here since involves NFS. Ran manual tests to verify that the DB directories were properly moved across NFS mounted directories. Have been running this internally for weeks. Author: Tom Graves Closes #17748 from tgravescs/SPARK-19812. 
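For illustration only (not part of the change itself), the following is a minimal sketch of the approach described above: relocating a non-empty recovery DB directory through the Hadoop local filesystem's `rename` instead of `java.nio.file.Files.move`. The class name and paths are hypothetical and stand in for a NodeManager local dir and the recovery path.

```java
import java.io.File;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

// Hypothetical sketch: move a recovery DB directory (a LevelDB dir, so it has entries)
// using the Hadoop local filesystem, which handles non-empty directories and moves
// across NFS mounts where java.nio.file.Files.move can fail.
public class RecoveryDbMoveSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.getLocal(conf);
    // Made-up locations for the example.
    Path copyFrom = new Path(new File("/tmp/nm-local-dir/registeredExecutors.ldb").toURI());
    Path newLoc = new Path("/tmp/recovery-dir/registeredExecutors.ldb");
    if (!fs.rename(copyFrom, newLoc)) {
      // FileSystem.rename reports some failures by returning false rather than throwing.
      System.err.println("Failed to move " + copyFrom + " to " + newLoc);
    }
  }
}
```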
--- .../network/yarn/YarnShuffleService.java | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index c7620d0fe128..4acc203153e5 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.ByteBuffer; -import java.nio.file.Files; import java.util.List; import java.util.Map; @@ -340,9 +339,9 @@ protected Path getRecoveryPath(String fileName) { * when it previously was not. If YARN NM recovery is enabled it uses that path, otherwise * it will uses a YARN local dir. */ - protected File initRecoveryDb(String dbFileName) { + protected File initRecoveryDb(String dbName) { if (_recoveryPath != null) { - File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbFileName); + File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName); if (recoveryFile.exists()) { return recoveryFile; } @@ -350,7 +349,7 @@ protected File initRecoveryDb(String dbFileName) { // db doesn't exist in recovery path go check local dirs for it String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs"); for (String dir : localDirs) { - File f = new File(new Path(dir).toUri().getPath(), dbFileName); + File f = new File(new Path(dir).toUri().getPath(), dbName); if (f.exists()) { if (_recoveryPath == null) { // If NM recovery is not enabled, we should specify the recovery path using NM local @@ -363,17 +362,21 @@ protected File initRecoveryDb(String dbFileName) { // make sure to move all DBs to the recovery path from the old NM local dirs. // If another DB was initialized first just make sure all the DBs are in the same // location. - File newLoc = new File(_recoveryPath.toUri().getPath(), dbFileName); - if (!newLoc.equals(f)) { + Path newLoc = new Path(_recoveryPath, dbName); + Path copyFrom = new Path(f.toURI()); + if (!newLoc.equals(copyFrom)) { + logger.info("Moving " + copyFrom + " to: " + newLoc); try { - Files.move(f.toPath(), newLoc.toPath()); + // The move here needs to handle moving non-empty directories across NFS mounts + FileSystem fs = FileSystem.getLocal(_conf); + fs.rename(copyFrom, newLoc); } catch (Exception e) { // Fail to move recovery file to new path, just continue on with new DB location logger.error("Failed to move recovery file {} to the path {}", - dbFileName, _recoveryPath.toString(), e); + dbName, _recoveryPath.toString(), e); } } - return newLoc; + return new File(newLoc.toUri().getPath()); } } } @@ -381,7 +384,7 @@ protected File initRecoveryDb(String dbFileName) { _recoveryPath = new Path(localDirs[0]); } - return new File(_recoveryPath.toUri().getPath(), dbFileName); + return new File(_recoveryPath.toUri().getPath(), dbName); } /** From dbb06c689c157502cb081421baecce411832aad8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 26 Apr 2017 21:34:18 +0800 Subject: [PATCH 353/512] [MINOR][ML] Fix some PySpark & SparkR flaky tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Some PySpark & SparkR tests run with tiny dataset and tiny ```maxIter```, which means they are not converged. 
I don’t think checking intermediate result during iteration make sense, and these intermediate result may vulnerable and not stable, so we should switch to check the converged result. We hit this issue at #17746 when we upgrade breeze to 0.13.1. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #17757 from yanboliang/flaky-test. --- .../testthat/test_mllib_classification.R | 17 +---- python/pyspark/ml/classification.py | 71 ++++++++++--------- 2 files changed, 38 insertions(+), 50 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index af7cbdccf5d5..cbc708718286 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -284,22 +284,11 @@ test_that("spark.mlp", { c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = - c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "2.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "0.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +299,6 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-0.5793153, -4.652961, 6.216155, -6.649478, - -10.51147), tolerance = 1e-3) }) test_that("spark.naiveBayes", { diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 864968390ace..a9756ea4af99 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -185,34 +185,33 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors >>> bdf = sc.parallelize([ - ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") + ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)), + ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)), + ... 
Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF() + >>> blor = LogisticRegression(regParam=0.01, weightCol="weight") >>> blorModel = blor.fit(bdf) >>> blorModel.coefficients - DenseVector([5.4...]) + DenseVector([-1.080..., -0.646...]) >>> blorModel.intercept - -2.63... - >>> mdf = sc.parallelize([ - ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), - ... Row(label=2.0, weight=2.0, features=Vectors.dense(3.0))]).toDF() - >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", - ... family="multinomial") + 3.112... + >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" + >>> mdf = spark.read.format("libsvm").load(data_path) + >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial") >>> mlorModel = mlor.fit(mdf) >>> mlorModel.coefficientMatrix - DenseMatrix(3, 1, [-2.3..., 0.2..., 2.1...], 1) + SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1) >>> mlorModel.interceptVector - DenseVector([2.1..., 0.6..., -2.8...]) - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() + DenseVector([0.04..., -0.42..., 0.37...]) + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF() >>> result = blorModel.transform(test0).head() >>> result.prediction - 0.0 + 1.0 >>> result.probability - DenseVector([0.99..., 0.00...]) + DenseVector([0.02..., 0.97...]) >>> result.rawPrediction - DenseVector([8.12..., -8.12...]) - >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() + DenseVector([-3.54..., 3.54...]) + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> blorModel.transform(test1).head().prediction 1.0 >>> blor.setParams("vector") @@ -222,8 +221,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> lr_path = temp_path + "/lr" >>> blor.save(lr_path) >>> lr2 = LogisticRegression.load(lr_path) - >>> lr2.getMaxIter() - 5 + >>> lr2.getRegParam() + 0.01 >>> model_path = temp_path + "/lr_model" >>> blorModel.save(model_path) >>> model2 = LogisticRegressionModel.load(model_path) @@ -1480,31 +1479,33 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors - >>> df = sc.parallelize([ - ... Row(label=0.0, features=Vectors.dense(1.0, 0.8)), - ... Row(label=1.0, features=Vectors.sparse(2, [], [])), - ... Row(label=2.0, features=Vectors.dense(0.5, 0.5))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" + >>> df = spark.read.format("libsvm").load(data_path) + >>> lr = LogisticRegression(regParam=0.01) >>> ovr = OneVsRest(classifier=lr) >>> model = ovr.fit(df) - >>> [x.coefficients for x in model.models] - [DenseVector([4.9791, 2.426]), DenseVector([-4.1198, -5.9326]), DenseVector([-3.314, 5.2423])] + >>> model.models[0].coefficients + DenseVector([0.5..., -1.0..., 3.4..., 4.2...]) + >>> model.models[1].coefficients + DenseVector([-2.1..., 3.1..., -2.6..., -2.3...]) + >>> model.models[2].coefficients + DenseVector([0.3..., -3.4..., 1.0..., -1.1...]) >>> [x.intercept for x in model.models] - [-5.06544..., 2.30341..., -1.29133...] - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF() + [-2.7..., -2.5..., -1.3...] 
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF() >>> model.transform(test0).head().prediction - 1.0 - >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() - >>> model.transform(test1).head().prediction 0.0 - >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF() - >>> model.transform(test2).head().prediction + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction 2.0 + >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF() + >>> model.transform(test2).head().prediction + 0.0 >>> model_path = temp_path + "/ovr_model" >>> model.save(model_path) >>> model2 = OneVsRestModel.load(model_path) >>> model2.transform(test0).head().prediction - 1.0 + 0.0 .. versionadded:: 2.0.0 """ From 66dd5b83ff95d5f91f37dcdf6aac89faa0b871c5 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 26 Apr 2017 09:01:50 -0500 Subject: [PATCH 354/512] [SPARK-20391][CORE] Rename memory related fields in ExecutorSummay ## What changes were proposed in this pull request? This is a follow-up of #14617 to make the name of memory related fields more meaningful. Here for the backward compatibility, I didn't change `maxMemory` and `memoryUsed` fields. ## How was this patch tested? Existing UT and local verification. CC squito and tgravescs . Author: jerryshao Closes #17700 from jerryshao/SPARK-20391. --- .../apache/spark/ui/static/executorspage.js | 48 +++++++++-------- .../org/apache/spark/status/api/v1/api.scala | 11 ++-- .../apache/spark/ui/exec/ExecutorsPage.scala | 21 ++++---- .../executor_memory_usage_expectation.json | 51 +++++++++++-------- ...xecutor_node_blacklisting_expectation.json | 51 +++++++++++-------- 5 files changed, 105 insertions(+), 77 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 930a0698928d..cb9922d23c44 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -253,10 +253,14 @@ $(document).ready(function () { var deadTotalBlacklisted = 0; response.forEach(function (exec) { - exec.onHeapMemoryUsed = exec.hasOwnProperty('onHeapMemoryUsed') ? exec.onHeapMemoryUsed : 0; - exec.maxOnHeapMemory = exec.hasOwnProperty('maxOnHeapMemory') ? exec.maxOnHeapMemory : 0; - exec.offHeapMemoryUsed = exec.hasOwnProperty('offHeapMemoryUsed') ? exec.offHeapMemoryUsed : 0; - exec.maxOffHeapMemory = exec.hasOwnProperty('maxOffHeapMemory') ? exec.maxOffHeapMemory : 0; + var memoryMetrics = { + usedOnHeapStorageMemory: 0, + usedOffHeapStorageMemory: 0, + totalOnHeapStorageMemory: 0, + totalOffHeapStorageMemory: 0 + }; + + exec.memoryMetrics = exec.hasOwnProperty('memoryMetrics') ? 
exec.memoryMetrics : memoryMetrics; }); response.forEach(function (exec) { @@ -264,10 +268,10 @@ $(document).ready(function () { allRDDBlocks += exec.rddBlocks; allMemoryUsed += exec.memoryUsed; allMaxMemory += exec.maxMemory; - allOnHeapMemoryUsed += exec.onHeapMemoryUsed; - allOnHeapMaxMemory += exec.maxOnHeapMemory; - allOffHeapMemoryUsed += exec.offHeapMemoryUsed; - allOffHeapMaxMemory += exec.maxOffHeapMemory; + allOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + allOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + allOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + allOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; allDiskUsed += exec.diskUsed; allTotalCores += exec.totalCores; allMaxTasks += exec.maxTasks; @@ -286,10 +290,10 @@ $(document).ready(function () { activeRDDBlocks += exec.rddBlocks; activeMemoryUsed += exec.memoryUsed; activeMaxMemory += exec.maxMemory; - activeOnHeapMemoryUsed += exec.onHeapMemoryUsed; - activeOnHeapMaxMemory += exec.maxOnHeapMemory; - activeOffHeapMemoryUsed += exec.offHeapMemoryUsed; - activeOffHeapMaxMemory += exec.maxOffHeapMemory; + activeOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + activeOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + activeOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + activeOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; activeDiskUsed += exec.diskUsed; activeTotalCores += exec.totalCores; activeMaxTasks += exec.maxTasks; @@ -308,10 +312,10 @@ $(document).ready(function () { deadRDDBlocks += exec.rddBlocks; deadMemoryUsed += exec.memoryUsed; deadMaxMemory += exec.maxMemory; - deadOnHeapMemoryUsed += exec.onHeapMemoryUsed; - deadOnHeapMaxMemory += exec.maxOnHeapMemory; - deadOffHeapMemoryUsed += exec.offHeapMemoryUsed; - deadOffHeapMaxMemory += exec.maxOffHeapMemory; + deadOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + deadOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + deadOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + deadOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; deadDiskUsed += exec.diskUsed; deadTotalCores += exec.totalCores; deadMaxTasks += exec.maxTasks; @@ -431,10 +435,10 @@ $(document).ready(function () { { data: function (row, type) { if (type !== 'display') - return row.onHeapMemoryUsed; + return row.memoryMetrics.usedOnHeapStorageMemory; else - return (formatBytes(row.onHeapMemoryUsed, type) + ' / ' + - formatBytes(row.maxOnHeapMemory, type)); + return (formatBytes(row.memoryMetrics.usedOnHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOnHeapStorageMemory, type)); }, "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { $(nTd).addClass('on_heap_memory') @@ -443,10 +447,10 @@ $(document).ready(function () { { data: function (row, type) { if (type !== 'display') - return row.offHeapMemoryUsed; + return row.memoryMetrics.usedOffHeapStorageMemory; else - return (formatBytes(row.offHeapMemoryUsed, type) + ' / ' + - formatBytes(row.maxOffHeapMemory, type)); + return (formatBytes(row.memoryMetrics.usedOffHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOffHeapStorageMemory, type)); }, "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { $(nTd).addClass('off_heap_memory') diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index d159b9450ef5..56d8e51732ff 
100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -76,10 +76,13 @@ class ExecutorSummary private[spark]( val isBlacklisted: Boolean, val maxMemory: Long, val executorLogs: Map[String, String], - val onHeapMemoryUsed: Option[Long], - val offHeapMemoryUsed: Option[Long], - val maxOnHeapMemory: Option[Long], - val maxOffHeapMemory: Option[Long]) + val memoryMetrics: Option[MemoryMetrics]) + +class MemoryMetrics private[spark]( + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) class JobData private[spark]( val jobId: Int, diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 0a3c63d14ca8..b7cbed468517 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.status.api.v1.ExecutorSummary +import org.apache.spark.status.api.v1.{ExecutorSummary, MemoryMetrics} import org.apache.spark.ui.{UIUtils, WebUIPage} // This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive @@ -114,10 +114,16 @@ private[spark] object ExecutorsPage { val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem - val onHeapMemUsed = status.onHeapMemUsed - val offHeapMemUsed = status.offHeapMemUsed - val maxOnHeapMem = status.maxOnHeapMem - val maxOffHeapMem = status.maxOffHeapMem + val memoryMetrics = for { + onHeapUsed <- status.onHeapMemUsed + offHeapUsed <- status.offHeapMemUsed + maxOnHeap <- status.maxOnHeapMem + maxOffHeap <- status.maxOffHeapMem + } yield { + new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap) + } + + val diskUsed = status.diskUsed val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) @@ -142,10 +148,7 @@ private[spark] object ExecutorsPage { taskSummary.isBlacklisted, maxMem, taskSummary.executorLogs, - onHeapMemUsed, - offHeapMemUsed, - maxOnHeapMem, - maxOffHeapMem + memoryMetrics ) } } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index e732af266350..0f94e3b255db 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -22,10 +22,12 @@ "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -47,10 +49,12 @@ "isBlacklisted" : true, "maxMemory" : 908381388, "executorLogs" : { }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + 
"usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -75,11 +79,12 @@ "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" }, - - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -104,10 +109,12 @@ "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -132,8 +139,10 @@ "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index e732af266350..0f94e3b255db 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -22,10 +22,12 @@ "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -47,10 +49,12 @@ "isBlacklisted" : true, "maxMemory" : 908381388, "executorLogs" : { }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -75,11 +79,12 @@ "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : 
"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" }, - - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -104,10 +109,12 @@ "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -132,8 +139,10 @@ "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } } ] From 99c6cf9ef16bf8fae6edb23a62e46546a16bca80 Mon Sep 17 00:00:00 2001 From: Michal Szafranski Date: Wed, 26 Apr 2017 11:21:25 -0700 Subject: [PATCH 355/512] [SPARK-20473] Enabling missing types in ColumnVector.Array ## What changes were proposed in this pull request? ColumnVector implementations originally did not support some Catalyst types (float, short, and boolean). Now that they do, those types should be also added to the ColumnVector.Array. ## How was this patch tested? Tested using existing unit tests. Author: Michal Szafranski Closes #17772 from michal-databricks/spark-20473. --- .../apache/spark/sql/execution/vectorized/ColumnVector.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 354c878aca00..b105e60a2d34 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -180,7 +180,7 @@ public Object[] array() { @Override public boolean getBoolean(int ordinal) { - throw new UnsupportedOperationException(); + return data.getBoolean(offset + ordinal); } @Override @@ -188,7 +188,7 @@ public boolean getBoolean(int ordinal) { @Override public short getShort(int ordinal) { - throw new UnsupportedOperationException(); + return data.getShort(offset + ordinal); } @Override @@ -199,7 +199,7 @@ public short getShort(int ordinal) { @Override public float getFloat(int ordinal) { - throw new UnsupportedOperationException(); + return data.getFloat(offset + ordinal); } @Override From a277ae80a2836e6533b338d2b9c4e59ed8a1daae Mon Sep 17 00:00:00 2001 From: Michal Szafranski Date: Wed, 26 Apr 2017 12:47:37 -0700 Subject: [PATCH 356/512] [SPARK-20474] Fixing OnHeapColumnVector reallocation ## What changes were proposed in this pull request? 
OnHeapColumnVector reallocation copies to the new storage data up to 'elementsAppended'. This variable is only updated when using the ColumnVector.appendX API, while ColumnVector.putX is more commonly used. ## How was this patch tested? Tested using existing unit tests. Author: Michal Szafranski Closes #17773 from michal-databricks/spark-20474. --- .../vectorized/OnHeapColumnVector.java | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 9b410bacff5d..94ed32294cfa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -410,53 +410,53 @@ protected void reserveInternal(int newCapacity) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + System.arraycopy(this.arrayLengths, 0, newLengths, 0, capacity); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, capacity); } arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ByteType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ShortType) { if (shortData == null || shortData.length < newCapacity) { short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, capacity); shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity); intData = newData; } } else if (type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity); longData = newData; } } else if (type instanceof FloatType) { if (floatData == null || floatData.length < newCapacity) { float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, capacity); floatData = newData; } } else if (type instanceof DoubleType) { if (doubleData == 
null || doubleData.length < newCapacity) { double[] newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity); doubleData = newData; } } else if (resultStruct != null) { @@ -466,7 +466,7 @@ protected void reserveInternal(int newCapacity) { } byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, capacity); nulls = newNulls; capacity = newCapacity; From 2ba1eba371213d1ac3d1fa1552e5906e043c2ee4 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Wed, 26 Apr 2017 13:54:40 -0700 Subject: [PATCH 357/512] [SPARK-12868][SQL] Allow adding jars from hdfs ## What changes were proposed in this pull request? Spark 2.2 is going to be cut, it'll be great if SPARK-12868 can be resolved before that. There have been several PRs for this like [PR#16324](https://github.com/apache/spark/pull/16324) , but all of them are inactivity for a long time or have been closed. This PR added a SparkUrlStreamHandlerFactory, which relies on 'protocol' to choose the appropriate UrlStreamHandlerFactory like FsUrlStreamHandlerFactory to create URLStreamHandler. ## How was this patch tested? 1. Add a new unit test. 2. Check manually. Before: throw an exception with " failed unknown protocol: hdfs" screen shot 2017-03-17 at 9 07 36 pm After: screen shot 2017-03-18 at 11 42 18 am Author: Weiqing Yang Closes #17342 from weiqingy/SPARK-18910. --- .../org/apache/spark/sql/internal/SharedState.scala | 10 +++++++++- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 13 +++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index f834569e59b7..a93b70114607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.internal +import java.net.URL import java.util.Locale import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging @@ -154,7 +156,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } -object SharedState { +object SharedState extends Logging { + try { + URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory()) + } catch { + case e: Error => + logWarning("URL.setURLStreamHandlerFactory failed to set FsUrlStreamHandlerFactory") + } private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0dd9296a3f0f..3ecbf96b4196 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext +import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -2606,4 +2607,16 @@ class 
SQLQuerySuite extends QueryTest with SharedSQLContext { case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) } } + + test("SPARK-12868: Allow adding jars from hdfs ") { + val jarFromHdfs = "hdfs://doesnotmatter/test.jar" + val jarFromInvalidFs = "fffs://doesnotmatter/test.jar" + + // if 'hdfs' is not supported, MalformedURLException will be thrown + new URL(jarFromHdfs) + + intercept[MalformedURLException] { + new URL(jarFromInvalidFs) + } + } } From 66636ef0b046e5d1f340c3b8153d7213fa9d19c7 Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Wed, 26 Apr 2017 17:06:21 -0700 Subject: [PATCH 358/512] [SPARK-20435][CORE] More thorough redaction of sensitive information This change does a more thorough redaction of sensitive information from logs and UI Add unit tests that ensure that no regressions happen that leak sensitive information to the logs. The motivation for this change was appearance of password like so in `SparkListenerEnvironmentUpdate` in event logs under some JVM configurations: `"sun.java.command":"org.apache.spark.deploy.SparkSubmit ... --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ..." ` Previously redaction logic was only checking if the key matched the secret regex pattern, it'd redact it's value. That worked for most cases. However, in the above case, the key (sun.java.command) doesn't tell much, so the value needs to be searched. This PR expands the check to check for values as well. ## How was this patch tested? New unit tests added that ensure that no sensitive information is present in the event logs or the yarn logs. Old unit test in UtilsSuite was modified because the test was asserting that a non-sensitive property's value won't be redacted. However, the non-sensitive value had the literal "secret" in it which was causing it to redact. Simply updating the non-sensitive property's value to another arbitrary value (that didn't have "secret" in it) fixed it. Author: Mark Grover Closes #17725 from markgrover/spark-20435. --- .../spark/internal/config/package.scala | 4 +-- .../scheduler/EventLoggingListener.scala | 16 ++++++--- .../scala/org/apache/spark/util/Utils.scala | 22 +++++++++--- .../spark/deploy/SparkSubmitSuite.scala | 34 +++++++++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 10 ++++-- docs/configuration.md | 4 +-- .../spark/deploy/yarn/YarnClusterSuite.scala | 32 +++++++++++++---- 7 files changed, 100 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 89aeea493908..2f0a3064be11 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -244,8 +244,8 @@ package object config { ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + "driver and executor environments contain sensitive information. 
When this regex matches " + - "a property, its value is redacted from the environment UI and various logs like YARN " + - "and event logs.") + "a property key or value, the value is redacted from the environment UI and various logs " + + "like YARN and event logs.") .regexConf .createWithDefault("(?i)secret|password".r) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index aecb3a980e7c..a7dbf87915b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -252,11 +252,17 @@ private[spark] class EventLoggingListener( private[spark] def redactEvent( event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = { - // "Spark Properties" entry will always exist because the map is always populated with it. - val redactedProps = Utils.redact(sparkConf, event.environmentDetails("Spark Properties")) - val redactedEnvironmentDetails = event.environmentDetails + - ("Spark Properties" -> redactedProps) - SparkListenerEnvironmentUpdate(redactedEnvironmentDetails) + // environmentDetails maps a string descriptor to a set of properties + // Similar to: + // "JVM Information" -> jvmInformation, + // "Spark Properties" -> sparkProperties, + // ... + // where jvmInformation, sparkProperties, etc. are sequence of tuples. + // We go through the various of properties and redact sensitive information from them. + val redactedProps = event.environmentDetails.map{ case (name, props) => + name -> Utils.redact(sparkConf, props) + } + SparkListenerEnvironmentUpdate(redactedProps) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 943dde072327..e042badcdd4a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2606,10 +2606,24 @@ private[spark] object Utils extends Logging { } private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { - kvs.map { kv => - redactionPattern.findFirstIn(kv._1) - .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) } - .getOrElse(kv) + // If the sensitive information regex matches with either the key or the value, redact the value + // While the original intent was to only redact the value if the key matched with the regex, + // we've found that especially in verbose mode, the value of the property may contain sensitive + // information like so: + // "sun.java.command":"org.apache.spark.deploy.SparkSubmit ... \ + // --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ... + // + // And, in such cases, simply searching for the sensitive information regex in the key name is + // not sufficient. The values themselves have to be searched as well and redacted if matched. + // This does mean we may be accounting more false positives - for example, if the value of an + // arbitrary property contained the term 'password', we may redact the value from the UI and + // logs. In order to work around it, user would have to make the spark.redaction.regex property + // more specific. 
+ kvs.map { case (key, value) => + redactionPattern.findFirstIn(key) + .orElse(redactionPattern.findFirstIn(value)) + .map { _ => (key, REDACTION_REPLACEMENT_TEXT) } + .getOrElse((key, value)) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7c2ec01a03d0..a43839a8815f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,8 +21,10 @@ import java.io._ import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer +import scala.io.Source import com.google.common.io.ByteStreams +import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -34,6 +36,7 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -404,6 +407,37 @@ class SparkSubmitSuite runSparkSubmit(args) } + test("launch simple application with spark-submit with redaction") { + val testDir = Utils.createTempDir() + testDir.deleteOnExit() + val testDirPath = new Path(testDir.getAbsolutePath()) + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) + try { + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password", + "--conf", "spark.eventLog.enabled=true", + "--conf", "spark.eventLog.testing=true", + "--conf", s"spark.eventLog.dir=${testDirPath.toUri.toString}", + "--conf", "spark.hadoop.fs.defaultFS=unsupported://example.com", + unusedJar.toString) + runSparkSubmit(args) + val listStatus = fileSystem.listStatus(testDirPath) + val logData = EventLoggingListener.openEventLog(listStatus.last.getPath, fileSystem) + Source.fromInputStream(logData).getLines().foreach { line => + assert(!line.contains("secret_password")) + } + } finally { + Utils.deleteRecursively(testDir) + } + } + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 8ed09749ffd5..3339d5b35d3b 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1010,15 +1010,19 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD", "spark.my.password", "spark.my.sECreT") - secretKeys.foreach { key => sparkConf.set(key, "secret_password") } + secretKeys.foreach { key => sparkConf.set(key, "sensitive_value") } // Set a non-secret key - sparkConf.set("spark.regular.property", "not_a_secret") + sparkConf.set("spark.regular.property", "regular_value") + // Set a property with a regular key but secret in the value + 
sparkConf.set("spark.sensitive.property", "has_secret_in_value") // Redact sensitive information val redactedConf = Utils.redact(sparkConf, sparkConf.getAll).toMap // Assert that secret information got redacted while the regular property remained the same secretKeys.foreach { key => assert(redactedConf(key) === Utils.REDACTION_REPLACEMENT_TEXT) } - assert(redactedConf("spark.regular.property") === "not_a_secret") + assert(redactedConf("spark.regular.property") === "regular_value") + assert(redactedConf("spark.sensitive.property") === Utils.REDACTION_REPLACEMENT_TEXT) + } } diff --git a/docs/configuration.md b/docs/configuration.md index 8b53e92ccd41..1d8d963016c7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -372,8 +372,8 @@ Apart from these, the following properties are also available, and may be useful
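[Editor's note] The key-or-value matching that this patch introduces in `Utils.redact` can be sketched outside Spark in a few lines. The following is a minimal standalone Scala sketch of the idea only, not the actual Spark implementation; the object name, the replacement text, and the sample properties are illustrative assumptions.

```scala
import scala.util.matching.Regex

object RedactionSketch {
  // Illustrative stand-in for Spark's replacement text constant.
  private val ReplacementText = "*********(redacted)"

  // Redact the value whenever the pattern matches either the key or the value,
  // mirroring the behaviour described in the commit message above.
  def redact(pattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] =
    kvs.map { case (key, value) =>
      pattern.findFirstIn(key)
        .orElse(pattern.findFirstIn(value))
        .map(_ => (key, ReplacementText))
        .getOrElse((key, value))
    }

  def main(args: Array[String]): Unit = {
    val pattern = "(?i)secret|password".r
    val props = Seq(
      "spark.my.password" -> "hunter2",                        // key matches
      "sun.java.command" -> "--conf FOO_PASSWORD=secret_pw",   // only the value matches
      "spark.regular.property" -> "regular_value"              // no match, left unchanged
    )
    redact(pattern, props).foreach(println)
  }
}
```

Matching on the value as well as the key is what catches cases like `sun.java.command`, at the cost of occasional false positives; as the code comment in the patch notes, a more specific `spark.redaction.regex` can work around those.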
    diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 99fb58a28934..59adb7e22d18 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -24,6 +24,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.concurrent.duration._ +import scala.io.Source import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} @@ -87,24 +88,30 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testBasicYarnApp(false) } - test("run Spark in yarn-client mode with different configurations") { + test("run Spark in yarn-client mode with different configurations, ensuring redaction") { testBasicYarnApp(true, Map( "spark.driver.memory" -> "512m", "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", - "spark.executor.instances" -> "2" + "spark.executor.instances" -> "2", + // Sending some senstive information, which we'll make sure gets redacted + "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, + "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) } - test("run Spark in yarn-cluster mode with different configurations") { + test("run Spark in yarn-cluster mode with different configurations, ensuring redaction") { testBasicYarnApp(false, Map( "spark.driver.memory" -> "512m", "spark.driver.cores" -> "1", "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", - "spark.executor.instances" -> "2" + "spark.executor.instances" -> "2", + // Sending some senstive information, which we'll make sure gets redacted + "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, + "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) } @@ -349,6 +356,7 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc private object YarnClusterDriver extends Logging with Matchers { val WAIT_TIMEOUT_MILLIS = 10000 + val SECRET_PASSWORD = "secret_password" def main(args: Array[String]): Unit = { if (args.length != 1) { @@ -395,6 +403,13 @@ private object YarnClusterDriver extends Logging with Matchers { assert(executorInfos.nonEmpty) executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) + info.logUrlMap.values.foreach { url => + val log = Source.fromURL(url).mkString + assert( + !log.contains(SECRET_PASSWORD), + s"Executor logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " + ) + } } // If we are running in yarn-cluster mode, verify that driver logs links and present and are @@ -406,8 +421,13 @@ private object YarnClusterDriver extends Logging with Matchers { assert(driverLogs.contains("stderr")) assert(driverLogs.contains("stdout")) val urlStr = driverLogs("stderr") - // Ensure that this is a valid URL, else this will throw an exception - new URL(urlStr) + driverLogs.foreach { kv => + val log = Source.fromURL(kv._2).mkString + assert( + !log.contains(SECRET_PASSWORD), + s"Driver logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " + ) + } val containerId = YarnSparkHadoopUtil.get.getContainerId val user = Utils.getCurrentUserName() assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) From b4724db19a10387a803cd7beec14facf7ad1894a Mon Sep 17 00:00:00 
2001 From: Takeshi Yamamuro Date: Wed, 26 Apr 2017 22:18:01 -0700 Subject: [PATCH 359/512] [SPARK-20425][SQL] Support a vertical display mode for Dataset.show ## What changes were proposed in this pull request? This pr added a new display mode for `Dataset.show` to print output rows vertically (one line per column value). In the current master, when printing Dataset with many columns, the readability is low like; ``` scala> val df = spark.range(100).selectExpr((0 until 100).map(i => s"rand() AS c$i"): _*) scala> df.show(3, 0) +------------------+------------------+------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+------------------+------------------+-------------------+------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+--------------------+-------------------+------------------+-------------------+--------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+-------------------+------------------+-------------------+------------------+------------------+-----------------+-------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+------------------+-------------------+-------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+ |c0 |c1 |c2 |c3 |c4 |c5 |c6 |c7 |c8 |c9 |c10 |c11 |c12 |c13 |c14 |c15 |c16 |c17 |c18 |c19 |c20 |c21 |c22 |c23 |c24 |c25 |c26 |c27 |c28 |c29 |c30 |c31 |c32 |c33 |c34 |c35 |c36 |c37 |c38 |c39 |c40 |c41 |c42 |c43 |c44 |c45 |c46 |c47 |c48 |c49 |c50 |c51 |c52 |c53 |c54 |c55 |c56 |c57 |c58 |c59 |c60 |c61 |c62 |c63 |c64 |c65 |c66 |c67 |c68 |c69 |c70 |c71 |c72 |c73 |c74 |c75 |c76 |c77 |c78 |c79 |c80 |c81 |c82 |c83 |c84 |c85 |c86 |c87 |c88 |c89 |c90 |c91 |c92 |c93 |c94 |c95 |c96 |c97 |c98 |c99 | 
+------------------+------------------+------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+------------------+------------------+-------------------+------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+--------------------+-------------------+------------------+-------------------+--------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+-------------------+------------------+-------------------+------------------+------------------+-----------------+-------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+------------------+-------------------+-------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+ |0.6306087152476858|0.9174349686288383|0.5511324165035159|0.3320844128641819 |0.7738486877101489|0.2154915886962553|0.4754997600674299 |0.922780639280355 |0.7136894772661909|0.2277580838165979|0.5926874459847249|0.40311408392226633|0.467830264333843 |0.8330466896984213|0.1893258482389527|0.6320849515511165 |0.7530911056912044 |0.06700254871955424|0.370528597355559 |0.2755437445193154|0.23704391110980128|0.8067400174905822|0.13597793616251852|0.1708888820162453|0.01672725007605702|0.983118121881555 |0.25040195628629924|0.060537253723083384|0.20000530582637488|0.3400572407133511|0.9375689433322597 |0.057039316954370256|0.8053269714347623|0.5247817572228813|0.28419308820527944|0.9798908885194533 |0.31805988175678146|0.7034448027077574|0.5400575751346084|0.25336322371116216|0.9361634546853429|0.6118681368289798|0.6295081549153907 |0.13417468943957422|0.41617137072255794|0.7267230869252035|0.023792726137561115|0.5776157058356362 |0.04884204913195467|0.26728716103441275|0.646680370807925 |0.9782712690657244 |0.16434031314818154|0.20985522381321275|0.24739842475440077 |0.26335189682977334|0.19604841662422068|0.10742950487300651|0.20283136488091502|0.3100312319723688|0.886959006630645 |0.25157102269776244|0.34428775168410786|0.3500506818575777|0.3781142441912052 |0.8560316444386715|0.4737104888956839|0.735903101602148|0.02236617130529006|0.8769074095835873 |0.2001426662503153|0.5534032319238532 |0.7289496620397098|0.41955191309992157|0.9337700133660436 
|0.34059094378451005|0.6419144759403556|0.08167496930341167|0.9947099478497635|0.48010888605366586|0.22314796858167918|0.17786598882331306|0.7351521162297135 |0.5422057170020095 |0.9521927872726792 |0.7459825486368227 |0.40907708791990627|0.8903819313311575|0.7251413746923618 |0.2977174938745204 |0.9515209660203555|0.9375968604766713|0.5087851740042524|0.4255237544908751 |0.8023768698664653|0.48003189618006703|0.1775841829745185|0.09050775629268382|0.6743909291138167 |0.2498415755876865 | |0.6866473844170801|0.4774360641212433|0.631696201340726 |0.33979113021468343|0.5663049010847052|0.7280190472258865|0.41370958502324806|0.9977433873622218|0.7671957338989901|0.2788708556233931|0.3355106391656496|0.88478952319287 |0.0333974166999893|0.6061744715862606|0.9617779139652359|0.22484954822341863|0.12770906021550898|0.5577789629508672 |0.2877649024640704|0.5566577406549361|0.9334933255278052 |0.9166720585157266|0.9689249324600591 |0.6367502457478598|0.7993572745928459 |0.23213222324218108|0.11928284054154137|0.6173493362456599 |0.0505122058694798 |0.9050228629552983|0.17112767911121707|0.47395598348370005 |0.5820498657823081|0.6241124650645072|0.18587258258036776|0.14987593554122225|0.3079446253653946 |0.9414228822867968|0.8362276265462365|0.9155655305576353 |0.5121559807153562|0.8963362656525707|0.22765970274318037|0.8177039187132797 |0.8190326635933787 |0.5256005177032199|0.8167598457269669 |0.030936807130934496|0.6733006585281015 |0.4208049626816347 |0.24603085738518538|0.22719198954208153|0.1622280557565281 |0.22217325159218038|0.014684419513742553|0.08987111517447499|0.2157764759142622 |0.8223414104088321 |0.4868624404491777 |0.4016191733088167|0.6169281906889263|0.15603611040433385|0.18289285085714913|0.9538408988218972|0.15037154865295121|0.5364516961987454|0.8077254873163031|0.712600478545675|0.7277477241003857 |0.19822912960348305|0.8305051199208777|0.18631911396566114|0.8909532487898342|0.3470409226992506 |0.35306974180587636|0.9107058868891469 |0.3321327206004986|0.48952332459050607|0.3630403307479373|0.5400046826340376 |0.5387377194310529 |0.42860539421837585|0.23214101630985995|0.21438968839794847|0.15370603160082352|0.04355605642700022|0.6096006707067466 |0.6933354157094292|0.06302172470859002|0.03174631856164001|0.664243581650643 |0.7833239547446621|0.696884598352864 |0.34626385933237736|0.9263495598791336|0.404818892816584 |0.2085585394755507|0.6150004897990109 |0.05391193524302473|0.28188484028329097| 
+------------------+------------------+------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+------------------+------------------+-------------------+------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+--------------------+-------------------+------------------+-------------------+--------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+------------------+------------------+-------------------+-------------------+-------------------+------------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------------+------------------+-------------------+-------------------+------------------+-------------------+------------------+------------------+-----------------+-------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+------------------+-------------------+-------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+-------------------+ only showing top 2 rows ``` `psql`, CLI for PostgreSQL, supports a vertical display mode for this case like: http://stackoverflow.com/questions/9604723/alternate-output-format-for-psql ``` -RECORD 0------------------- c0 | 0.6306087152476858 c1 | 0.9174349686288383 c2 | 0.5511324165035159 ... c98 | 0.05391193524302473 c99 | 0.28188484028329097 -RECORD 1------------------- c0 | 0.6866473844170801 c1 | 0.4774360641212433 c2 | 0.631696201340726 ... c98 | 0.05391193524302473 c99 | 0.28188484028329097 only showing top 2 rows ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #17733 from maropu/SPARK-20425. --- R/pkg/R/DataFrame.R | 8 +- python/pyspark/sql/dataframe.py | 15 +- .../scala/org/apache/spark/sql/Dataset.scala | 149 ++++++++++++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 112 +++++++++++++ 4 files changed, 247 insertions(+), 37 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index cd6f03a13d7c..7e57ba6287bb 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -194,6 +194,7 @@ setMethod("isLocal", #' 20 characters will be truncated. However, if set greater than zero, #' truncates strings longer than \code{truncate} characters and all cells #' will be aligned right. +#' @param vertical whether print output rows vertically (one line per column value). #' @param ... further arguments to be passed to or from other methods. 
#' @family SparkDataFrame functions #' @aliases showDF,SparkDataFrame-method @@ -210,12 +211,13 @@ setMethod("isLocal", #' @note showDF since 1.4.0 setMethod("showDF", signature(x = "SparkDataFrame"), - function(x, numRows = 20, truncate = TRUE) { + function(x, numRows = 20, truncate = TRUE, vertical = FALSE) { if (is.logical(truncate) && truncate) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(20)) + s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(20), vertical) } else { truncate2 <- as.numeric(truncate) - s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(truncate2)) + s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(truncate2), + vertical) } cat(s) }) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 774caf53f3a4..ff21bb5d2fb3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -290,13 +290,15 @@ def isStreaming(self): return self._jdf.isStreaming() @since(1.3) - def show(self, n=20, truncate=True): + def show(self, n=20, truncate=True, vertical=False): """Prints the first ``n`` rows to the console. :param n: Number of rows to show. :param truncate: If set to True, truncate strings longer than 20 chars by default. If set to a number greater than one, truncates long strings to length ``truncate`` and align cells right. + :param vertical: If set to True, print output rows vertically (one line + per column value). >>> df DataFrame[age: int, name: string] @@ -314,11 +316,18 @@ def show(self, n=20, truncate=True): | 2| Ali| | 5| Bob| +---+----+ + >>> df.show(vertical=True) + -RECORD 0----- + age | 2 + name | Alice + -RECORD 1----- + age | 5 + name | Bob """ if isinstance(truncate, bool) and truncate: - print(self._jdf.showString(n, 20)) + print(self._jdf.showString(n, 20, vertical)) else: - print(self._jdf.showString(n, int(truncate))) + print(self._jdf.showString(n, int(truncate), vertical)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 06dd5500718d..147e7651ce55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -240,8 +240,10 @@ class Dataset[T] private[sql]( * @param _numRows Number of rows to show * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). 
*/ - private[sql] def showString(_numRows: Int, truncate: Int = 20): String = { + private[sql] def showString( + _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) val hasMoreData = takeResult.length > numRows @@ -277,46 +279,80 @@ class Dataset[T] private[sql]( val sb = new StringBuilder val numCols = schema.fieldNames.length + // We set a minimum column width at '3' + val minimumColWidth = 3 - // Initialise the width of each column to a minimum value of '3' - val colWidths = Array.fill(numCols)(3) + if (!vertical) { + // Initialise the width of each column to a minimum value + val colWidths = Array.fill(numCols)(minimumColWidth) - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } } - }.addString(sb, "|", "|", "|\n") - sb.append(sep) + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - // data - rows.tail.map { - _.zipWithIndex.map { case (cell, i) => + // column names + rows.head.zipWithIndex.map { case (cell, i) => if (truncate > 0) { - StringUtils.leftPad(cell.toString, colWidths(i)) + StringUtils.leftPad(cell, colWidths(i)) } else { - StringUtils.rightPad(cell.toString, colWidths(i)) + StringUtils.rightPad(cell, colWidths(i)) } }.addString(sb, "|", "|", "|\n") - } - sb.append(sep) + sb.append(sep) + + // data + rows.tail.foreach { + _.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + } else { + // Extended display mode enabled + val fieldNames = rows.head + val dataRows = rows.tail + + // Compute the width of field name and data columns + val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) => + math.max(curMax, fieldName.length) + } + val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => + math.max(curMax, row.map(_.length).reduceLeftOption[Int] { case (cellMax, cell) => + math.max(cellMax, cell) + }.getOrElse(0)) + } + + dataRows.zipWithIndex.foreach { case (row, i) => + // "+ 5" in size means a character length except for padded names and data + val rowHeader = StringUtils.rightPad( + s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") + sb.append(rowHeader).append("\n") + row.zipWithIndex.map { case (cell, j) => + val fieldName = StringUtils.rightPad(fieldNames(j), fieldNameColWidth) + val data = StringUtils.rightPad(cell, dataColWidth) + s" $fieldName | $data " + }.addString(sb, "", "\n", "\n") + } + } - // For Data that has more than "numRows" records - if (hasMoreData) { + // Print a footer + if (vertical && data.isEmpty) { + // In a vertical mode, print an empty row set explicitly + sb.append("(0 rows)\n") + } else if (hasMoreData) { + // For Data that has more than "numRows" records val rowsString = if (numRows == 1) "row" 
else "rows" sb.append(s"only showing top $numRows $rowsString\n") } @@ -663,8 +699,59 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ + def show(numRows: Int, truncate: Int): Unit = show(numRows, truncate, vertical = false) + + /** + * Displays the Dataset in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * + * If `vertical` enabled, this command prints output rows vertically (one line per column value)? + * + * {{{ + * -RECORD 0------------------- + * year | 1980 + * month | 12 + * AVG('Adj Close) | 0.503218 + * AVG('Adj Close) | 0.595103 + * -RECORD 1------------------- + * year | 1981 + * month | 01 + * AVG('Adj Close) | 0.523289 + * AVG('Adj Close) | 0.570307 + * -RECORD 2------------------- + * year | 1982 + * month | 02 + * AVG('Adj Close) | 0.436504 + * AVG('Adj Close) | 0.475256 + * -RECORD 3------------------- + * year | 1983 + * month | 03 + * AVG('Adj Close) | 0.410516 + * AVG('Adj Close) | 0.442194 + * -RECORD 4------------------- + * year | 1984 + * month | 04 + * AVG('Adj Close) | 0.450090 + * AVG('Adj Close) | 0.483521 + * }}} + * + * @param numRows Number of rows to show + * @param truncate If set to more than 0, truncates strings to `truncate` characters and + * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). + * @group action + * @since 2.3.0 + */ // scalastyle:off println - def show(numRows: Int, truncate: Int): Unit = println(showString(numRows, truncate)) + def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = + println(showString(numRows, truncate, vertical)) // scalastyle:on println /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b4893b56a8a8..ef0de6f6f4ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -764,6 +764,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10, truncate = 20) === expectedAnswerForTrue) } + test("showString: truncate = [0, 20], vertical = true") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = "-RECORD 0----------------------\n" + + " value | 1 \n" + + "-RECORD 1----------------------\n" + + " value | 111111111111111111111 \n" + assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse) + val expectedAnswerForTrue = "-RECORD 0---------------------\n" + + " value | 1 \n" + + "-RECORD 1---------------------\n" + + " value | 11111111111111111... 
\n" + assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue) + } + test("showString: truncate = [3, 17]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() @@ -785,6 +800,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10, truncate = 17) === expectedAnswerForTrue) } + test("showString: truncate = [3, 17], vertical = true") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = "-RECORD 0----\n" + + " value | 1 \n" + + "-RECORD 1----\n" + + " value | 111 \n" + assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse) + val expectedAnswerForTrue = "-RECORD 0------------------\n" + + " value | 1 \n" + + "-RECORD 1------------------\n" + + " value | 11111111111111... \n" + assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue) + } + test("showString(negative)") { val expectedAnswer = """+---+-----+ ||key|value| @@ -795,6 +825,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(-1) === expectedAnswer) } + test("showString(negative), vertical = true") { + val expectedAnswer = "(0 rows)\n" + assert(testData.select($"*").showString(-1, vertical = true) === expectedAnswer) + } + test("showString(0)") { val expectedAnswer = """+---+-----+ ||key|value| @@ -805,6 +840,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(0) === expectedAnswer) } + test("showString(0), vertical = true") { + val expectedAnswer = "(0 rows)\n" + assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer) + } + test("showString: array") { val df = Seq( (Array(1, 2, 3), Array(1, 2, 3)), @@ -820,6 +860,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: array, vertical = true") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = "-RECORD 0--------\n" + + " _1 | [1, 2, 3] \n" + + " _2 | [1, 2, 3] \n" + + "-RECORD 1--------\n" + + " _1 | [2, 3, 4] \n" + + " _2 | [2, 3, 4] \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("showString: binary") { val df = Seq( ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), @@ -835,6 +889,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: binary, vertical = true") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = "-RECORD 0---------------\n" + + " _1 | [31 32] \n" + + " _2 | [41 42 43 2E] \n" + + "-RECORD 1---------------\n" + + " _1 | [33 34] \n" + + " _2 | [31 32 33 34 36] \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("showString: minimum column width") { val df = Seq( (1, 1), @@ -850,6 +918,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: minimum column width, vertical = true") { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + val expectedAnswer = "-RECORD 0--\n" + + " _1 | 1 \n" + + " _2 | 1 \n" + + "-RECORD 
1--\n" + + " _1 | 2 \n" + + " _2 | 2 \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("SPARK-7319 showString") { val expectedAnswer = """+---+-----+ ||key|value| @@ -861,6 +943,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(1) === expectedAnswer) } + test("SPARK-7319 showString, vertical = true") { + val expectedAnswer = "-RECORD 0----\n" + + " key | 1 \n" + + " value | 1 \n" + + "only showing top 1 row\n" + assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) + } + test("SPARK-7327 show with empty dataFrame") { val expectedAnswer = """+---+-----+ ||key|value| @@ -870,6 +960,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").filter($"key" < 0).showString(1) === expectedAnswer) } + test("SPARK-7327 show with empty dataFrame, vertical = true") { + assert(testData.select($"*").filter($"key" < 0).showString(1, vertical = true) === "(0 rows)\n") + } + test("SPARK-18350 show with session local timezone") { val d = Date.valueOf("2016-12-01") val ts = Timestamp.valueOf("2016-12-01 00:00:00") @@ -894,6 +988,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-18350 show with session local timezone, vertical = true") { + val d = Date.valueOf("2016-12-01") + val ts = Timestamp.valueOf("2016-12-01 00:00:00") + val df = Seq((d, ts)).toDF("d", "ts") + val expectedAnswer = "-RECORD 0------------------\n" + + " d | 2016-12-01 \n" + + " ts | 2016-12-01 00:00:00 \n" + assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) + + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + + val expectedAnswer = "-RECORD 0------------------\n" + + " d | 2016-12-01 \n" + + " ts | 2016-12-01 08:00:00 \n" + assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) + } + } + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) From b58cf77c4db49ba236b779905a943f025c6aaedd Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 27 Apr 2017 00:29:43 -0700 Subject: [PATCH 360/512] [DOCS][MINOR] Add missing since to SparkR repeat_string note. ## What changes were proposed in this pull request? Replace note repeat_string 2.3.0 with note repeat_string since 2.3.0 ## How was this patch tested? `create-docs.sh` Author: zero323 Closes #17779 from zero323/REPEAT-NOTE. --- R/pkg/R/functions.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 752e4c5c7189..6b91fa5bde67 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3796,7 +3796,7 @@ setMethod("split_string", #' # This is equivalent to the following SQL expression #' first(selectExpr(df, "repeat(value, 3)")) #' } -#' @note repeat_string 2.3.0 +#' @note repeat_string since 2.3.0 setMethod("repeat_string", signature(x = "Column", n = "numeric"), function(x, n) { From ba7666274e71f1903e5050a5e53fbdcd21debde5 Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 27 Apr 2017 00:34:20 -0700 Subject: [PATCH 361/512] [SPARK-20208][DOCS][FOLLOW-UP] Add FP-Growth to SparkR programming guide ## What changes were proposed in this pull request? Add `spark.fpGrowth` to SparkR programming guide. ## How was this patch tested? Manual tests. Author: zero323 Closes #17775 from zero323/SPARK-20208-FOLLOW-UP. 
--- docs/sparkr.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/sparkr.md b/docs/sparkr.md index e015ab260fca..c3336ac2ce86 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -504,6 +504,10 @@ SparkR supports the following machine learning algorithms currently: * [`spark.als`](api/R/spark.als.html): [`Alternating Least Squares (ALS)`](ml-collaborative-filtering.html#collaborative-filtering) +#### Frequent Pattern Mining + +* [`spark.fpGrowth`](api/R/spark.fpGrowth.html) : [`FP-growth`](ml-frequent-pattern-mining.html#fp-growth) + #### Statistics * [`spark.kstest`](api/R/spark.kstest.html): `Kolmogorov-Smirnov Test` From 7633933e54ffb08ab9d959be5f76c26fae29d1d9 Mon Sep 17 00:00:00 2001 From: Davis Shepherd Date: Thu, 27 Apr 2017 18:06:12 +0000 Subject: [PATCH 362/512] [SPARK-20483] Mesos Coarse mode may starve other Mesos frameworks ## What changes were proposed in this pull request? Set maxCores to be a multiple of the smallest executor that can be launched. This ensures that we correctly detect the condition where no more executors will be launched when spark.cores.max is not a multiple of spark.executor.cores ## How was this patch tested? This was manually tested with other sample frameworks measuring their incoming offers to determine if starvation would occur. dbtsai mgummelt Author: Davis Shepherd Closes #17786 from dgshep/fix_mesos_max_cores. --- .../MesosCoarseGrainedSchedulerBackend.scala | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 2a36ec4fa811..8f5b97ccb1f8 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -60,8 +60,16 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val maxCoresOption = conf.getOption("spark.cores.max").map(_.toInt) + private val executorCoresOption = conf.getOption("spark.executor.cores").map(_.toInt) + + private val minCoresPerExecutor = executorCoresOption.getOrElse(1) + // Maximum number of cores to acquire - private val maxCores = maxCoresOption.getOrElse(Int.MaxValue) + private val maxCores = { + val cores = maxCoresOption.getOrElse(Int.MaxValue) + // Set maxCores to a multiple of smallest executor we can launch + cores - (cores % minCoresPerExecutor) + } private val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) @@ -489,8 +497,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def executorCores(offerCPUs: Int): Int = { - sc.conf.getInt("spark.executor.cores", - math.min(offerCPUs, maxCores - totalCoresAcquired)) + executorCoresOption.getOrElse( + math.min(offerCPUs, maxCores - totalCoresAcquired) + ) } override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { From 561e9cc390b429e4252f59f00a7ca4f6f8c853f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 27 Apr 2017 11:31:01 -0700 Subject: [PATCH 363/512] [SPARK-20421][CORE] Mark internal listeners as deprecated. These listeners weren't really meant for external consumption, but they're public and marked with DeveloperApi. 
Adding the deprecated tag warns people that they may soon go away (as they will as part of the work for SPARK-18085). Note that not all types made public by https://github.com/apache/spark/pull/648 are being deprecated. Some remaining types are still exposed through the SparkListener API. Also note the text for StorageStatus is a tiny bit different, since I'm not so sure I'll be able to remove it. But the effect for the users should be the same (they should stop trying to use it). Author: Marcelo Vanzin Closes #17766 from vanzin/SPARK-20421. --- .../scala/org/apache/spark/storage/StorageStatusListener.scala | 1 + core/src/main/scala/org/apache/spark/storage/StorageUtils.scala | 1 + core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala | 1 + core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala | 1 + .../scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 1 + core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala | 1 + 6 files changed, 6 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 1b30d4fa93bc..ac60f795915a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -30,6 +30,7 @@ import org.apache.spark.scheduler._ * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageStatusListener(conf: SparkConf) extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 8f0d181fc8fe..e9694fdbca2d 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging * class cannot mutate the source of the information. Accesses are not thread-safe. 
*/ @DeveloperApi +@deprecated("This class may be removed or made private in a future release.", "2.2.0") class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 70b3ffd95e60..8c18464e6477 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -32,6 +32,7 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en * A SparkListener that prepares information to be displayed on the EnvironmentTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class EnvironmentListener extends SparkListener { var jvmInformation = Seq[(String, String)]() var sparkProperties = Seq[(String, String)]() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 03851293eb2f..aabf6e0c63c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -62,6 +62,7 @@ private[ui] case class ExecutorTaskSummary( * A SparkListener that prepares information to be displayed on the ExecutorsTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index f78db5ab80d1..8870187f2219 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -41,6 +41,7 @@ import org.apache.spark.ui.jobs.UIData._ * updating the internal data structures concurrently. */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Define a handful of type aliases so that data structures' types can serve as documentation. diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index c212362557be..148efb134e14 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,6 +39,7 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing From 85c6ce61930490e2247fb4b0e22dfebbb8b6a1ee Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 27 Apr 2017 14:06:07 -0500 Subject: [PATCH 364/512] [SPARK-20426] Lazy initialization of FileSegmentManagedBuffer for shuffle service. ## What changes were proposed in this pull request? When application contains large amount of shuffle blocks. 
NodeManager requires lots of memory to keep metadata (`FileSegmentManagedBuffer`) in `StreamManager`. When the number of shuffle blocks is big enough, the NodeManager can run out of memory (OOM). This PR proposes lazy initialization of `FileSegmentManagedBuffer` in the shuffle service. ## How was this patch tested? Manually tested. Author: jinxing Closes #17744 from jinxing64/SPARK-20426. --- .../shuffle/ExternalShuffleBlockHandler.java | 31 ++++++++++++------- .../ExternalShuffleBlockHandlerSuite.java | 4 +-- .../ExternalShuffleIntegrationSuite.java | 5 ++- .../network/netty/NettyBlockRpcServer.scala | 9 +++--- 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 6daf9609d76d..c0f1da50f5e6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -21,7 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; -import java.util.List; +import java.util.Iterator; import java.util.Map; import com.codahale.metrics.Gauge; @@ -30,7 +30,6 @@ import com.codahale.metrics.MetricSet; import com.codahale.metrics.Timer; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -93,14 +92,25 @@ protected void handleMessage( OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); - List blocks = Lists.newArrayList(); - long totalBlockSize = 0; - for (String blockId : msg.blockIds) { - final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, blockId); - totalBlockSize += block != null ? block.size() : 0; - blocks.add(block); - } - long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); + Iterator iter = new Iterator() { + private int index = 0; + + @Override + public boolean hasNext() { + return index < msg.blockIds.length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, + msg.blockIds[index]); + index++; + metrics.blockTransferRateBytes.mark(block != null ? 
block.size() : 0); + return block; + } + }; + + long streamId = streamManager.registerStream(client.getClientId(), iter); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, @@ -109,7 +119,6 @@ protected void handleMessage( getRemoteAddress(client.getChannel())); } callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); - metrics.blockTransferRateBytes.mark(totalBlockSize); } finally { responseDelayContext.stop(); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e47a72c9d16c..4d48b1897038 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -88,8 +88,6 @@ public void testOpenShuffleBlocks() { ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) .toByteBuffer(); handler.receive(client, openBlocks, callback); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); @@ -107,6 +105,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b8ae04eefb97..7a33b6821792 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -216,9 +216,8 @@ public void testFetchWrongExecutor() throws Exception { registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - // Both still fail, as we start by checking for all block. 
- assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); } @Test diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 2ed8a00df702..305fd9a6de10 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -56,11 +56,12 @@ class NettyBlockRpcServer( message match { case openBlocks: OpenBlocks => - val blocks: Seq[ManagedBuffer] = - openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) + val blocksNum = openBlocks.blockIds.length + val blocks = for (i <- (0 until blocksNum).view) + yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) - logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) + logTrace(s"Registered streamId $streamId with $blocksNum buffers") + responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. From 26ac2ce05cbaf8f152347219403e31491e9c9bf1 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Thu, 27 Apr 2017 12:08:16 -0700 Subject: [PATCH 365/512] [SPARK-20482][SQL] Resolving Casts is too strict on having time zone set ## What changes were proposed in this pull request? Relax the requirement that a `TimeZoneAwareExpression` has to have its `timeZoneId` set to be considered resolved. With this change, a `Cast` (which is a `TimeZoneAwareExpression`) can be considered resolved if the `(fromType, toType)` combination doesn't require time zone information. Also de-relaxed test cases in `CastSuite` so Casts in that test suite don't get a default`timeZoneId = Option("GMT")`. ## How was this patch tested? Ran the de-relaxed`CastSuite` and it's passing. Also ran the SQL unit tests and they're passing too. Author: Kris Mok Closes #17777 from rednaxelafx/fix-catalyst-cast-timezone. --- .../spark/sql/catalyst/expressions/Cast.scala | 32 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 4 +-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bb1273f5c3d8..a53ef426f79b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -89,6 +89,31 @@ object Cast { case _ => false } + /** + * Return true if we need to use the `timeZone` information casting `from` type to `to` type. + * The patterns matched reflect the current implementation in the Cast node. + * c.f. 
usage of `timeZone` in: + * * Cast.castToString + * * Cast.castToDate + * * Cast.castToTimestamp + */ + def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match { + case (StringType, TimestampType) => true + case (DateType, TimestampType) => true + case (TimestampType, StringType) => true + case (TimestampType, DateType) => true + case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType) + case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue) + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).exists { + case (fromField, toField) => + needsTimeZone(fromField.dataType, toField.dataType) + } + case _ => false + } + /** * Return true iff we may truncate during casting `from` type to `to` type. e.g. long -> int, * timestamp -> date. @@ -165,6 +190,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + // When this cast involves TimeZone, it's only resolved if the timeZoneId is set; + // Otherwise behave like Expression.resolved. + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined) + + private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) + // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 22f3f3514fa4..a7ffa884d228 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String */ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { - private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = { + private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = { v match { case lit: Expression => Cast(lit, targetType, timeZoneId) case _ => Cast(Literal(v), targetType, timeZoneId) @@ -47,7 +47,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } private def checkNullCast(from: DataType, to: DataType): Unit = { - checkEvaluation(cast(Literal.create(null, from), to), null) + checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null) } test("null cast") { From a4aa4665a6775b514b714c88b70576090d2b4a7e Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Thu, 27 Apr 2017 12:13:16 -0700 Subject: [PATCH 366/512] [SPARK-20487][SQL] `HiveTableScan` node is quite verbose in explained plan ## What changes were proposed in this pull request? Changed `TreeNode.argString` to handle `CatalogTable` separately (otherwise it would call the default `toString` on the `CatalogTable`) ## How was this patch tested? 
- Expanded scope of existing unit test to ensure that verbose information is not present - Manual testing Before ``` scala> hc.sql(" SELECT * FROM my_table WHERE name = 'foo' ").explain(true) == Parsed Logical Plan == 'Project [*] +- 'Filter ('name = foo) +- 'UnresolvedRelation `my_table` == Analyzed Logical Plan == user_id: bigint, name: string, ds: string Project [user_id#13L, name#14, ds#15] +- Filter (name#14 = foo) +- SubqueryAlias my_table +- CatalogRelation CatalogTable( Database: default Table: my_table Owner: tejasp Created: Fri Apr 14 17:05:50 PDT 2017 Last Access: Wed Dec 31 16:00:00 PST 1969 Type: MANAGED Provider: hive Properties: [serialization.format=1] Statistics: 9223372036854775807 bytes Location: file:/tmp/warehouse/my_table Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe InputFormat: org.apache.hadoop.mapred.TextInputFormat OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat Partition Provider: Catalog Partition Columns: [`ds`] Schema: root -- user_id: long (nullable = true) -- name: string (nullable = true) -- ds: string (nullable = true) ), [user_id#13L, name#14], [ds#15] == Optimized Logical Plan == Filter (isnotnull(name#14) && (name#14 = foo)) +- CatalogRelation CatalogTable( Database: default Table: my_table Owner: tejasp Created: Fri Apr 14 17:05:50 PDT 2017 Last Access: Wed Dec 31 16:00:00 PST 1969 Type: MANAGED Provider: hive Properties: [serialization.format=1] Statistics: 9223372036854775807 bytes Location: file:/tmp/warehouse/my_table Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe InputFormat: org.apache.hadoop.mapred.TextInputFormat OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat Partition Provider: Catalog Partition Columns: [`ds`] Schema: root -- user_id: long (nullable = true) -- name: string (nullable = true) -- ds: string (nullable = true) ), [user_id#13L, name#14], [ds#15] == Physical Plan == *Filter (isnotnull(name#14) && (name#14 = foo)) +- HiveTableScan [user_id#13L, name#14, ds#15], CatalogRelation CatalogTable( Database: default Table: my_table Owner: tejasp Created: Fri Apr 14 17:05:50 PDT 2017 Last Access: Wed Dec 31 16:00:00 PST 1969 Type: MANAGED Provider: hive Properties: [serialization.format=1] Statistics: 9223372036854775807 bytes Location: file:/tmp/warehouse/my_table Serde Library: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe InputFormat: org.apache.hadoop.mapred.TextInputFormat OutputFormat: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat Partition Provider: Catalog Partition Columns: [`ds`] Schema: root -- user_id: long (nullable = true) -- name: string (nullable = true) -- ds: string (nullable = true) ), [user_id#13L, name#14], [ds#15] ``` After ``` scala> hc.sql(" SELECT * FROM my_table WHERE name = 'foo' ").explain(true) == Parsed Logical Plan == 'Project [*] +- 'Filter ('name = foo) +- 'UnresolvedRelation `my_table` == Analyzed Logical Plan == user_id: bigint, name: string, ds: string Project [user_id#13L, name#14, ds#15] +- Filter (name#14 = foo) +- SubqueryAlias my_table +- CatalogRelation `default`.`my_table`, [user_id#13L, name#14], [ds#15] == Optimized Logical Plan == Filter (isnotnull(name#14) && (name#14 = foo)) +- CatalogRelation `default`.`my_table`, [user_id#13L, name#14], [ds#15] == Physical Plan == *Filter (isnotnull(name#14) && (name#14 = foo)) +- HiveTableScan [user_id#13L, name#14, ds#15], CatalogRelation `default`.`my_table`, [user_id#13L, name#14], [ds#15] ``` Author: Tejas Patil Closes #17780 from 
tejasapatil/SPARK-20487_verbose_plan. --- .../spark/sql/catalyst/trees/TreeNode.scala | 1 + .../sql/hive/execution/HiveExplainSuite.scala | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cc4c0835954b..b091315f24f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -444,6 +444,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case None => Nil case Some(null) => Nil case Some(any) => any :: Nil + case table: CatalogTable => table.identifier :: Nil case other => other :: Nil }.mkString(", ") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 8a37bc3665d3..ebafe6de0c83 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -47,7 +47,23 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto checkKeywordsNotExist(sql(" explain select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==") + "== Optimized Logical Plan ==", + "Owner", + "Database", + "Created", + "Last Access", + "Type", + "Provider", + "Properties", + "Statistics", + "Location", + "Serde Library", + "InputFormat", + "OutputFormat", + "Partition Provider", + "Schema" + ) + checkKeywordsExist(sql(" explain extended select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", From 039e32ca19d113e3be2c09171c7c921698be7ab8 Mon Sep 17 00:00:00 2001 From: Davis Shepherd Date: Thu, 27 Apr 2017 20:25:52 +0000 Subject: [PATCH 367/512] [SPARK-20483][MINOR] Test for Mesos Coarse mode may starve other Mesos frameworks ## What changes were proposed in this pull request? Add test case for scenarios where executor.cores is set as a (non)divisor of spark.cores.max This tests the change in #17786 ## How was this patch tested? Ran the existing test suite with the new tests dbtsai Author: Davis Shepherd Closes #17788 from dgshep/add_mesos_test. 
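The arithmetic behind these scenarios can be sketched in isolation. The snippet below is only an illustration of the budget check the new tests exercise (the object and method names are made up, not taken from the scheduler backend): with `spark.cores.max = 4` and `spark.executor.cores = 3`, one executor consumes 3 of the 4 allowed cores, no later offer can ever fit another 3-core executor, and such offers have to be declined with a filter rather than held.

```scala
// Hedged sketch of the core-budget check; not the MesosCoarseGrainedSchedulerBackend itself.
object CoreBudgetSketch {
  // An executor can only launch if it fits inside the remaining spark.cores.max budget.
  def canLaunchExecutor(executorCores: Int, maxCores: Int, coresAcquired: Int): Boolean =
    coresAcquired + executorCores <= maxCores

  def main(args: Array[String]): Unit = {
    val (executorCores, maxCores) = (3, 4)
    var acquired = 0
    // Offer 1: 0 + 3 <= 4, so one executor is launched.
    if (canLaunchExecutor(executorCores, maxCores, acquired)) acquired += executorCores
    // Offer 2: 3 + 3 > 4, so the offer should be declined with a filter
    // instead of being held, to avoid starving other Mesos frameworks.
    println(canLaunchExecutor(executorCores, maxCores, acquired)) // false
  }
}
```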
--- ...osCoarseGrainedSchedulerBackendSuite.scala | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index c040f05d93b3..0418bfbaa5ed 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -199,6 +199,40 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite verifyDeclinedOffer(driver, createOfferId("o2"), true) } + test("mesos declines offers with a filter when maxCores not a multiple of executor.cores") { + val maxCores = 4 + val executorCores = 3 + setBackend(Map( + "spark.cores.max" -> maxCores.toString, + "spark.executor.cores" -> executorCores.toString + )) + val executorMemory = backend.executorMemory(sc) + offerResources(List( + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1) + )) + verifyTaskLaunched(driver, "o1") + verifyDeclinedOffer(driver, createOfferId("o2"), true) + } + + test("mesos declines offers with a filter when reached spark.cores.max with executor.cores") { + val maxCores = 4 + val executorCores = 2 + setBackend(Map( + "spark.cores.max" -> maxCores.toString, + "spark.executor.cores" -> executorCores.toString + )) + val executorMemory = backend.executorMemory(sc) + offerResources(List( + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1) + )) + verifyTaskLaunched(driver, "o1") + verifyTaskLaunched(driver, "o2") + verifyDeclinedOffer(driver, createOfferId("o3"), true) + } + test("mesos assigns tasks round-robin on offers") { val executorCores = 4 val maxCores = executorCores * 2 From 606432a13ad22d862c7cb5028ad6fe73c9985423 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 27 Apr 2017 20:48:43 +0000 Subject: [PATCH 368/512] [SPARK-20047][ML] Constrained Logistic Regression ## What changes were proposed in this pull request? MLlib ```LogisticRegression``` should support bound constrained optimization (only for L2 regularization). Users can add bound constraints to coefficients to make the solver produce solution in the specified range. Under the hood, we call Breeze [```L-BFGS-B```](https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/optimize/LBFGSB.scala) as the solver for bound constrained optimization. But in the current breeze implementation, there are some bugs in L-BFGS-B, and https://github.com/scalanlp/breeze/pull/633 fixed them. We need to upgrade dependent breeze later, and currently we use the workaround L-BFGS-B in this PR temporary for reviewing. ## How was this patch tested? Unit tests. Author: Yanbo Liang Closes #17715 from yanboliang/spark-20047. 
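A usage sketch of the new bound-constrained API, assuming a spark-shell session and a binomial training DataFrame `training` with 4 features (neither is part of this patch); bounds only combine with L2 regularization, so `elasticNetParam` stays at its 0.0 default:

```scala
// Hedged sketch: `training` (a DataFrame with "label"/"features" columns) and the
// 4-feature binomial shape are assumptions, not taken from the patch.
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.{Matrices, Vectors}

val lr = new LogisticRegression()
  .setFitIntercept(true)
  .setRegParam(0.1) // only L2 is supported with bounds, so elasticNetParam is left at 0.0
  // For binomial regression the bound matrices are (1, numFeatures); intercept bounds have size 1.
  .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(0.0)))
  .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(1.0)))
  .setLowerBoundsOnIntercepts(Vectors.dense(0.0))
  .setUpperBoundsOnIntercepts(Vectors.dense(1.0))

// Because bounds are set, Breeze's L-BFGS-B is chosen as the solver.
val model = lr.fit(training)
println(model.coefficients) // each coefficient lands in [0, 1]; the intercept in [0, 1]
```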
--- .../classification/LogisticRegression.scala | 223 ++++++++- .../LogisticRegressionSuite.scala | 466 +++++++++++++++++- 2 files changed, 682 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 44b3478e0c3d..d7dde329ed00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -178,11 +178,86 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } + /** + * The lower bounds on coefficients if fitting under bound constrained optimization. + * The bound matrix must be compatible with the shape (1, number of features) for binomial + * regression, or (number of classes, number of features) for multinomial regression. + * Otherwise, it throws exception. + * + * @group param + */ + @Since("2.2.0") + val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients", + "The lower bounds on coefficients if fitting under bound constrained optimization.") + + /** @group getParam */ + @Since("2.2.0") + def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients) + + /** + * The upper bounds on coefficients if fitting under bound constrained optimization. + * The bound matrix must be compatible with the shape (1, number of features) for binomial + * regression, or (number of classes, number of features) for multinomial regression. + * Otherwise, it throws exception. + * + * @group param + */ + @Since("2.2.0") + val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients", + "The upper bounds on coefficients if fitting under bound constrained optimization.") + + /** @group getParam */ + @Since("2.2.0") + def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients) + + /** + * The lower bounds on intercepts if fitting under bound constrained optimization. + * The bounds vector size must be equal with 1 for binomial regression, or the number + * of classes for multinomial regression. Otherwise, it throws exception. + * + * @group param + */ + @Since("2.2.0") + val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts", + "The lower bounds on intercepts if fitting under bound constrained optimization.") + + /** @group getParam */ + @Since("2.2.0") + def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts) + + /** + * The upper bounds on intercepts if fitting under bound constrained optimization. + * The bound vector size must be equal with 1 for binomial regression, or the number + * of classes for multinomial regression. Otherwise, it throws exception. 
+ * + * @group param + */ + @Since("2.2.0") + val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts", + "The upper bounds on intercepts if fitting under bound constrained optimization.") + + /** @group getParam */ + @Since("2.2.0") + def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts) + + protected def usingBoundConstrainedOptimization: Boolean = { + isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) || + isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts) + } + override protected def validateAndTransformSchema( schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { checkThresholdConsistency() + if (usingBoundConstrainedOptimization) { + require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " + + s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.") + } + if (!$(fitIntercept)) { + require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts), + "Pls don't set bounds on intercepts if fitting without intercept.") + } super.validateAndTransformSchema(schema, fitting, featuresDataType) } } @@ -217,6 +292,9 @@ class LogisticRegression @Since("1.2.0") ( * For alpha in (0,1), the penalty is a combination of L1 and L2. * Default is 0.0 which is an L2 penalty. * + * Note: Fitting under bound constrained optimization only supports L2 regularization, + * so throws exception if this param is non-zero value. + * * @group setParam */ @Since("1.4.0") @@ -312,6 +390,71 @@ class LogisticRegression @Since("1.2.0") ( def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) setDefault(aggregationDepth -> 2) + /** + * Set the lower bounds on coefficients if fitting under bound constrained optimization. + * + * @group setParam + */ + @Since("2.2.0") + def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value) + + /** + * Set the upper bounds on coefficients if fitting under bound constrained optimization. + * + * @group setParam + */ + @Since("2.2.0") + def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value) + + /** + * Set the lower bounds on intercepts if fitting under bound constrained optimization. + * + * @group setParam + */ + @Since("2.2.0") + def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value) + + /** + * Set the upper bounds on intercepts if fitting under bound constrained optimization. 
+ * + * @group setParam + */ + @Since("2.2.0") + def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value) + + private def assertBoundConstrainedOptimizationParamsValid( + numCoefficientSets: Int, + numFeatures: Int): Unit = { + if (isSet(lowerBoundsOnCoefficients)) { + require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets && + $(lowerBoundsOnCoefficients).numCols == numFeatures) + } + if (isSet(upperBoundsOnCoefficients)) { + require($(upperBoundsOnCoefficients).numRows == numCoefficientSets && + $(upperBoundsOnCoefficients).numCols == numFeatures) + } + if (isSet(lowerBoundsOnIntercepts)) { + require($(lowerBoundsOnIntercepts).size == numCoefficientSets) + } + if (isSet(upperBoundsOnIntercepts)) { + require($(upperBoundsOnIntercepts).size == numCoefficientSets) + } + if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { + require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray) + .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always " + + "less than or equal to upperBoundsOnCoefficients, but found: " + + s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " + + s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.") + } + if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) { + require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray) + .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always " + + "less than or equal to upperBoundsOnIntercepts, but found: " + + s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " + + s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.") + } + } + private var optInitialModel: Option[LogisticRegressionModel] = None private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { @@ -378,6 +521,11 @@ class LogisticRegression @Since("1.2.0") ( } val numCoefficientSets = if (isMultinomial) numClasses else 1 + // Check params interaction is valid if fitting under bound constrained optimization. + if (usingBoundConstrainedOptimization) { + assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures) + } + if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + @@ -397,7 +545,7 @@ class LogisticRegression @Since("1.2.0") ( val isConstantLabel = histogram.count(_ != 0.0) == 1 - if ($(fitIntercept) && isConstantLabel) { + if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + s"will be zeros. 
Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax @@ -434,8 +582,53 @@ class LogisticRegression @Since("1.2.0") ( $(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial, $(aggregationDepth)) + val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets + + val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = { + if (usingBoundConstrainedOptimization) { + val lowerBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity) + val upperBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity) + val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients) + val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients) + val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts) + val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts) + + var i = 0 + while (i < numCoeffsPlusIntercepts) { + val coefficientSetIndex = i % numCoefficientSets + val featureIndex = i / numCoefficientSets + if (featureIndex < numFeatures) { + if (isSetLowerBoundsOnCoefficients) { + lowerBounds(i) = $(lowerBoundsOnCoefficients)( + coefficientSetIndex, featureIndex) * featuresStd(featureIndex) + } + if (isSetUpperBoundsOnCoefficients) { + upperBounds(i) = $(upperBoundsOnCoefficients)( + coefficientSetIndex, featureIndex) * featuresStd(featureIndex) + } + } else { + if (isSetLowerBoundsOnIntercepts) { + lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex) + } + if (isSetUpperBoundsOnIntercepts) { + upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex) + } + } + i += 1 + } + (lowerBounds, upperBounds) + } else { + (null, null) + } + } + val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { - new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + if (lowerBounds != null && upperBounds != null) { + new BreezeLBFGSB( + BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol)) + } else { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + } } else { val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { @@ -546,6 +739,26 @@ class LogisticRegression @Since("1.2.0") ( math.log(histogram(1) / histogram(0))) } + if (usingBoundConstrainedOptimization) { + // Make sure all initial values locate in the corresponding bound. + var i = 0 + while (i < numCoeffsPlusIntercepts) { + val coefficientSetIndex = i % numCoefficientSets + val featureIndex = i / numCoefficientSets + if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i)) + { + initialCoefWithInterceptMatrix.update( + coefficientSetIndex, featureIndex, lowerBounds(i)) + } else if ( + initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i)) + { + initialCoefWithInterceptMatrix.update( + coefficientSetIndex, featureIndex, upperBounds(i)) + } + i += 1 + } + } + val states = optimizer.iterations(new CachedDiffFunction(costFun), new BDV[Double](initialCoefWithInterceptMatrix.toArray)) @@ -599,7 +812,7 @@ class LogisticRegression @Since("1.2.0") ( if (isIntercept) interceptVec.toArray(classIndex) = value } - if ($(regParam) == 0.0 && isMultinomial) { + if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) { /* When no regularization is applied, the multinomial coefficients lack identifiability because we do not use a pivot class. 
We can add any constant value to the coefficients @@ -620,7 +833,7 @@ class LogisticRegression @Since("1.2.0") ( } // center the intercepts when using multinomial algorithm - if ($(fitIntercept) && isMultinomial) { + if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) { val interceptArray = interceptVec.toArray val interceptMean = interceptArray.sum / interceptArray.length (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 83f575e83828..bf6bfe30bfe2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -150,6 +150,54 @@ class LogisticRegressionSuite assert(!model.hasSummary) } + test("logistic regression: illegal params") { + val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnCoefficients1 = Matrices.dense(1, 4, Array(0.0, 1.0, 1.0, 0.0)) + val upperBoundsOnCoefficients2 = Matrices.dense(1, 3, Array(1.0, 0.0, 1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(1.0) + + // Work well when only set bound in one side. 
+ new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .fit(binaryDataset) + + withClue("bound constrained optimization only supports L2 regularization") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setElasticNetParam(1.0) + .fit(binaryDataset) + } + } + + withClue("lowerBoundsOnCoefficients should less than or equal to upperBoundsOnCoefficients") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients1) + .fit(binaryDataset) + } + } + + withClue("the coefficients bound matrix mismatched with shape (1, number of features)") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients2) + .fit(binaryDataset) + } + } + + withClue("bounds on intercepts should not be set if fitting without intercept") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(false) + .fit(binaryDataset) + } + } + } + test("empty probabilityCol") { val lr = new LogisticRegression().setProbabilityCol("") val model = lr.fit(smallBinaryDataset) @@ -610,6 +658,107 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-3) } + test("binary logistic regression with intercept without regularization with bound") { + // Bound constrained optimization with bound on one side. + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnIntercepts = Vectors.dense(1.0) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected1 = Vectors.dense(0.06079437, 0.0, -0.26351059, -0.59102199) + val interceptExpected1 = 1.0 + + assert(model1.intercept ~== interceptExpected1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpected1 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.intercept ~== interceptExpected1 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected1 relTol 1E-3) + + // Bound constrained optimization with bound on both side. 
+ val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(0.0, -1.0, 0.0, -1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(0.0) + + val trainer3 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer4 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model3 = trainer3.fit(binaryDataset) + val model4 = trainer4.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected3 = Vectors.dense(0.0, 0.0, 0.0, -0.71708632) + val interceptExpected3 = 0.58776113 + + assert(model3.intercept ~== interceptExpected3 relTol 1E-3) + assert(model3.coefficients ~= coefficientsExpected3 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model4.intercept ~== interceptExpected3 relTol 1E-3) + assert(model4.coefficients ~= coefficientsExpected3 relTol 1E-3) + + // Bound constrained optimization with infinite bound on both side. + val trainer5 = new LogisticRegression() + .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity)) + .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity)) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer6 = new LogisticRegression() + .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity)) + .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity)) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model5 = trainer5.fit(binaryDataset) + val model6 = trainer6.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + // It should be same as unbound constrained optimization with LBFGS. + val coefficientsExpected5 = Vectors.dense(-0.5734389, 0.8911736, -0.3878645, -0.8060570) + val interceptExpected5 = 2.7355261 + + assert(model5.intercept ~== interceptExpected5 relTol 1E-3) + assert(model5.coefficients ~= coefficientsExpected5 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. 
+ assert(model6.intercept ~== interceptExpected5 relTol 1E-3) + assert(model6.coefficients ~= coefficientsExpected5 relTol 1E-3) + } + test("binary logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true) .setWeightCol("weight") @@ -650,6 +799,34 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-2) } + test("binary logistic regression without intercept without regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)).toSparse + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected = Vectors.dense(0.20847553, 0.0, -0.24240289, -0.55568071) + + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpected relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.intercept ~== 0.0 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression with intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight") @@ -815,6 +992,40 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-3) } + test("binary logistic regression with intercept with L2 regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnIntercepts = Vectors.dense(1.0) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setRegParam(1.37) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setRegParam(1.37) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. 
+ val coefficientsExpectedWithStd = Vectors.dense(-0.06985003, 0.0, -0.04794278, -0.10168595) + val interceptExpectedWithStd = 0.45750141 + val coefficientsExpected = Vectors.dense(-0.0494524, 0.0, -0.11360797, -0.06313577) + val interceptExpected = 0.53722967 + + assert(model1.intercept ~== interceptExpectedWithStd relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3) + assert(model2.intercept ~== interceptExpected relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight") @@ -864,6 +1075,35 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-2) } + test("binary logistic regression without intercept with L2 regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setRegParam(1.37) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setRegParam(1.37) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = Vectors.dense(-0.00796538, 0.0, -0.0394228, -0.0873314) + val coefficientsExpected = Vectors.dense(0.01105972, 0.0, -0.08574949, -0.05079558) + + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3) + assert(model2.intercept ~== 0.0 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression with intercept with ElasticNet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(200) .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") @@ -1084,7 +1324,6 @@ class LogisticRegressionSuite } test("multinomial logistic regression with intercept without regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) @@ -1152,6 +1391,110 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with intercept without regularization with bound") { + // Bound constrained optimization with bound on one side. 
+ val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected1 = new DenseMatrix(3, 4, Array( + 2.52076464, 2.73596057, 1.87984904, 2.73264492, + 1.93302281, 3.71363303, 1.50681746, 1.93398782, + 2.37839917, 1.93601818, 1.81924758, 2.45191255), isTransposed = true) + val interceptsExpected1 = Vectors.dense(1.00010477, 3.44237083, 4.86740286) + + checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected1) + assert(model1.interceptVector ~== interceptsExpected1 relTol 0.01) + checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected1) + assert(model2.interceptVector ~== interceptsExpected1 relTol 0.01) + + // Bound constrained optimization with bound on both side. + val upperBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(2.0)) + val upperBoundsOnIntercepts = Vectors.dense(Array.fill(3)(2.0)) + + val trainer3 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer4 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model3 = trainer3.fit(multinomialDataset) + val model4 = trainer4.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected3 = new DenseMatrix(3, 4, Array( + 1.61967097, 1.16027835, 1.45131448, 1.97390431, + 1.30529317, 2.0, 1.12985473, 1.26652854, + 1.61647195, 1.0, 1.40642959, 1.72985589), isTransposed = true) + val interceptsExpected3 = Vectors.dense(1.0, 2.0, 2.0) + + checkCoefficientsEquivalent(model3.coefficientMatrix, coefficientsExpected3) + assert(model3.interceptVector ~== interceptsExpected3 relTol 0.01) + checkCoefficientsEquivalent(model4.coefficientMatrix, coefficientsExpected3) + assert(model4.interceptVector ~== interceptsExpected3 relTol 0.01) + + // Bound constrained optimization with infinite bound on both side. 
+ val trainer5 = new LogisticRegression() + .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity))) + .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity))) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer6 = new LogisticRegression() + .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity))) + .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity))) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model5 = trainer5.fit(multinomialDataset) + val model6 = trainer6.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + // It should be same as unbound constrained optimization with LBFGS. + val coefficientsExpected5 = new DenseMatrix(3, 4, Array( + 0.24337896, -0.05916156, 0.14446790, 0.35976165, + -0.3443375, 0.9181331, -0.2283959, -0.4388066, + 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) + val interceptsExpected5 = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) + + checkCoefficientsEquivalent(model5.coefficientMatrix, coefficientsExpected5) + assert(model5.interceptVector ~== interceptsExpected5 relTol 0.01) + checkCoefficientsEquivalent(model6.coefficientMatrix, coefficientsExpected5) + assert(model6.interceptVector ~== interceptsExpected5 relTol 0.01) + } + test("multinomial logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) @@ -1220,6 +1563,35 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression without intercept without regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. 
+ val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.62410051, 1.38219391, 1.34486618, 1.74641729, + 1.23058989, 2.71787825, 1.0, 1.00007073, + 1.79478632, 1.14360459, 1.33011603, 1.55093897), isTransposed = true) + + checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected) + assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) + checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected) + assert(model2.interceptVector.toArray === Array.fill(3)(0.0)) + } + test("multinomial logistic regression with intercept with L1 regularization") { // use tighter constraints because OWL-QN solver takes longer to converge @@ -1518,6 +1890,46 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with intercept with L2 regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setRegParam(0.1) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setRegParam(0.1) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.0, 1.01647497, + 1.0, 1.44105616, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0), isTransposed = true) + val interceptsExpectedWithStd = Vectors.dense(2.52055893, 1.0, 2.560682) + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.03189386, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0), isTransposed = true) + val interceptsExpected = Vectors.dense(1.06418835, 1.0, 1.20494701) + + assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd relTol 0.01) + assert(model1.interceptVector ~== interceptsExpectedWithStd relTol 0.01) + assert(model2.coefficientMatrix ~== coefficientsExpected relTol 0.01) + assert(model2.interceptVector ~== interceptsExpected relTol 0.01) + } + test("multinomial logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight") @@ -1615,6 +2027,41 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression without intercept with L2 regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setRegParam(0.1) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setRegParam(0.1) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by 
https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array( + 1.01324653, 1.0, 1.0, 1.0415767, + 1.0, 1.0, 1.0, 1.0, + 1.02244888, 1.0, 1.0, 1.0), isTransposed = true) + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.03932259, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.03274649, 1.0), isTransposed = true) + + assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd absTol 0.01) + assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) + assert(model2.coefficientMatrix ~== coefficientsExpected absTol 0.01) + assert(model2.interceptVector.toArray === Array.fill(3)(0.0)) + } + test("multinomial logistic regression with intercept with elasticnet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) @@ -2273,4 +2720,19 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) testData } + + /** + * When no regularization is applied, the multinomial coefficients lack identifiability + * because we do not use a pivot class. We can add any constant value to the coefficients + * and get the same likelihood. If fitting under bound constrained optimization, we don't + * choose the mean centered coefficients like what we do for unbound problems, since they + * may out of the bounds. We use this function to check whether two coefficients are equivalent. + */ + def checkCoefficientsEquivalent(coefficients1: Matrix, coefficients2: Matrix): Unit = { + coefficients1.colIter.zip(coefficients2.colIter).foreach { case (col1: Vector, col2: Vector) => + (col1.asBreeze - col2.asBreeze).toArray.toSeq.sliding(2).foreach { + case Seq(v1, v2) => assert(v1 ~= v2 absTol 1E-3) + } + } + } } From 01c999e7f94d5e6c2fce67304dc62351dfbdf963 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 27 Apr 2017 13:55:03 -0700 Subject: [PATCH 369/512] [SPARK-20461][CORE][SS] Use UninterruptibleThread for Executor and fix the potential hang in CachedKafkaConsumer ## What changes were proposed in this pull request? This PR changes Executor's threads to `UninterruptibleThread` so that we can use `runUninterruptibly` in `CachedKafkaConsumer`. However, this is just best effort to avoid hanging forever. If the user uses`CachedKafkaConsumer` in another thread (e.g., create a new thread or Future), the potential hang may still happen. ## How was this patch tested? The new added test. Author: Shixiong Zhu Closes #17761 from zsxwing/int. 
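The thread behaviour the executor change relies on can be sketched as below. This is a standalone illustration, not code from the patch; `UninterruptibleThread` is `private[spark]`, so such a snippet only compiles inside the `org.apache.spark` package tree (as the new `ExecutorSuite` test does), and the `Thread.sleep` stands in for a Kafka poll that must not be interrupted (KAFKA-1894):

```scala
// Hedged sketch: interrupts delivered while inside runUninterruptibly are deferred
// until the block finishes, which is why the executor's task-launch workers are
// switched to UninterruptibleThread. Package and object names are illustrative only.
package org.apache.spark.sketch

import org.apache.spark.util.UninterruptibleThread

object UninterruptibleSketch {
  def main(args: Array[String]): Unit = {
    val worker = new UninterruptibleThread("demo-worker") {
      override def run(): Unit = runUninterruptibly {
        // Stand-in for a call such as KafkaConsumer.poll() that may hang forever
        // if interrupted; the interrupt is not seen while this block is running.
        Thread.sleep(1000)
      }
    }
    worker.start()
    worker.interrupt() // recorded, but only takes effect after runUninterruptibly returns
    worker.join()
  }
}
```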
--- .../org/apache/spark/executor/Executor.scala | 19 +++++++++++++++++-- .../spark/util/UninterruptibleThread.scala | 8 +++++++- .../apache/spark/executor/ExecutorSuite.scala | 13 +++++++++++++ .../sql/kafka010/CachedKafkaConsumer.scala | 15 +++++++++++++-- 4 files changed, 50 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 18f04391d64c..51b6c373c4da 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,13 +23,15 @@ import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.control.NonFatal +import com.google.common.util.concurrent.ThreadFactoryBuilder + import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -84,7 +86,20 @@ private[spark] class Executor( } // Start worker thread pool - private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") + private val threadPool = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Executor task launch worker-%d") + .setThreadFactory(new ThreadFactory { + override def newThread(r: Runnable): Thread = + // Use UninterruptibleThread to run tasks so that we can allow running codes without being + // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622, + // will hang forever if some methods are interrupted. + new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder + }) + .build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } private val executorSource = new ExecutorSource(threadPool, executorId) // Pool used for threads that supervise task killing / cancellation private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index f0b68f0cb7e2..27922b31949b 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy * * Note: "runUninterruptibly" should be called only in `this` thread. 
*/ -private[spark] class UninterruptibleThread(name: String) extends Thread(name) { +private[spark] class UninterruptibleThread( + target: Runnable, + name: String) extends Thread(target, name) { + + def this(name: String) { + this(null, name) + } /** A monitor to protect "uninterruptible" and "interrupted" */ private val uninterruptibleLock = new Object diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index f47e574b4fc4..efcad140350b 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -44,6 +44,7 @@ import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.UninterruptibleThread class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { @@ -158,6 +159,18 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(failReason.isInstanceOf[FetchFailed]) } + test("Executor's worker threads should be UninterruptibleThread") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("executor thread test") + .set("spark.ui.enabled", "false") + sc = new SparkContext(conf) + val executorThread = sc.parallelize(Seq(1), 1).map { _ => + Thread.currentThread.getClass.getName + }.collect().head + assert(executorThread === classOf[UninterruptibleThread].getName) + } + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 6d76904fb0e5..bf6c0900c97e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.util.UninterruptibleThread /** @@ -62,11 +63,20 @@ private[kafka010] case class CachedKafkaConsumer private( case class AvailableOffsetRange(earliest: Long, latest: Long) + private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { + case ut: UninterruptibleThread => + ut.runUninterruptibly(body) + case _ => + logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " + + "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894") + body + } + /** * Return the available offset range of the current partition. It's a pair of the earliest offset * and the latest offset. 
*/ - def getAvailableOffsetRange(): AvailableOffsetRange = { + def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible { consumer.seekToBeginning(Set(topicPartition).asJava) val earliestOffset = consumer.position(topicPartition) consumer.seekToEnd(Set(topicPartition).asJava) @@ -92,7 +102,8 @@ private[kafka010] case class CachedKafkaConsumer private( offset: Long, untilOffset: Long, pollTimeoutMs: Long, - failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + failOnDataLoss: Boolean): + ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible { require(offset < untilOffset, s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") From 823baca2cb8edb62885af547d3511c9e8923cefd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 27 Apr 2017 13:58:44 -0700 Subject: [PATCH 370/512] [SPARK-20452][SS][KAFKA] Fix a potential ConcurrentModificationException for batch Kafka DataFrame ## What changes were proposed in this pull request? Cancel a batch Kafka query but one of task cannot be cancelled, and rerun the same DataFrame may cause ConcurrentModificationException because it may launch two tasks sharing the same group id. This PR always create a new consumer when `reuseKafkaConsumer = false` to avoid ConcurrentModificationException. It also contains other minor fixes. ## How was this patch tested? Jenkins. Author: Shixiong Zhu Closes #17752 from zsxwing/kafka-fix. --- .../sql/kafka010/CachedKafkaConsumer.scala | 12 +- .../sql/kafka010/KafkaOffsetReader.scala | 6 +- .../spark/sql/kafka010/KafkaRelation.scala | 30 +++- .../sql/kafka010/KafkaSourceProvider.scala | 147 ++++++++---------- .../spark/sql/kafka010/KafkaSourceRDD.scala | 19 ++- .../spark/streaming/kafka010/KafkaRDD.scala | 2 +- 6 files changed, 119 insertions(+), 97 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index bf6c0900c97e..7c4f38e02fb2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -287,7 +287,7 @@ private[kafka010] case class CachedKafkaConsumer private( reportDataLoss0(failOnDataLoss, finalMessage, cause) } - private def close(): Unit = consumer.close() + def close(): Unit = consumer.close() private def seek(offset: Long): Unit = { logDebug(s"Seeking to $groupId $topicPartition $offset") @@ -382,7 +382,7 @@ private[kafka010] object CachedKafkaConsumer extends Logging { // If this is reattempt at running the task, then invalidate cache and start with // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { removeKafkaConsumer(topic, partition, kafkaParams) val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) consumer.inuse = true @@ -398,6 +398,14 @@ private[kafka010] object CachedKafkaConsumer extends Logging { } } + /** Create an [[CachedKafkaConsumer]] but don't put it into cache. 
*/ + def createUncached( + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { + new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) + } + private def reportDataLoss0( failOnDataLoss: Boolean, finalMessage: String, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 2696d6f089d2..3e65949a6fd1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -95,8 +95,10 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - consumer.close() - kafkaReaderThread.shutdownNow() + runUninterruptibly { + consumer.close() + } + kafkaReaderThread.shutdown() } /** diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index f180bbad6e36..97bd28316932 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.util.UUID import org.apache.kafka.common.TopicPartition @@ -33,9 +34,9 @@ import org.apache.spark.unsafe.types.UTF8String private[kafka010] class KafkaRelation( override val sqlContext: SQLContext, - kafkaReader: KafkaOffsetReader, - executorKafkaParams: ju.Map[String, Object], + strategy: ConsumerStrategy, sourceOptions: Map[String, String], + specifiedKafkaParams: Map[String, String], failOnDataLoss: Boolean, startingOffsets: KafkaOffsetRangeLimit, endingOffsets: KafkaOffsetRangeLimit) @@ -53,9 +54,27 @@ private[kafka010] class KafkaRelation( override def schema: StructType = KafkaOffsetReader.kafkaSchema override def buildScan(): RDD[Row] = { + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy, + KafkaSourceProvider.kafkaParamsForDriver(specifiedKafkaParams), + sourceOptions, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + // Leverage the KafkaReader to obtain the relevant partition offsets - val fromPartitionOffsets = getPartitionOffsets(startingOffsets) - val untilPartitionOffsets = getPartitionOffsets(endingOffsets) + val (fromPartitionOffsets, untilPartitionOffsets) = { + try { + (getPartitionOffsets(kafkaOffsetReader, startingOffsets), + getPartitionOffsets(kafkaOffsetReader, endingOffsets)) + } finally { + kafkaOffsetReader.close() + } + } + // Obtain topicPartitions in both from and until partition offset, ignoring // topic partitions that were added and/or deleted between the two above calls. if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) { @@ -82,6 +101,8 @@ private[kafka010] class KafkaRelation( offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. 
+ val executorKafkaParams = + KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) val rdd = new KafkaSourceRDD( sqlContext.sparkContext, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr => @@ -98,6 +119,7 @@ private[kafka010] class KafkaRelation( } private def getPartitionOffsets( + kafkaReader: KafkaOffsetReader, kafkaOffsets: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = { def validateTopicPartitions(partitions: Set[TopicPartition], partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index ab1ce347cbe3..3cb4d8cad12c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -111,10 +111,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { validateBatchOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters @@ -131,20 +127,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) assert(endingRelationOffsets != EarliestOffsetRangeLimit) - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaRelation( sqlContext, - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - failOnDataLoss(caseInsensitiveParams), - startingRelationOffsets, - endingRelationOffsets) + strategy(caseInsensitiveParams), + sourceOptions = parameters, + specifiedKafkaParams = specifiedKafkaParams, + failOnDataLoss = failOnDataLoss(caseInsensitiveParams), + startingOffsets = startingRelationOffsets, + endingOffsets = endingRelationOffsets) } override def createSink( @@ -213,46 +203,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) } - private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) = - ConfigUpdater("source", specifiedKafkaParams) - .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) - .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - - // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial - // offsets by itself instead of counting on KafkaConsumer. 
- .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") - - // So that consumers in the driver does not commit offsets unnecessarily - .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") - - // So that the driver does not pull too much data - .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) - - // If buffer config is not set, set it to reasonable value to work around - // buffer issues (see KAFKA-3135) - .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .build() - - private def kafkaParamsForExecutors( - specifiedKafkaParams: Map[String, String], uniqueGroupId: String) = - ConfigUpdater("executor", specifiedKafkaParams) - .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) - .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - - // Make sure executors do only what the driver tells them. - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") - - // So that consumers in executors do not mess with any existing group id - .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") - - // So that consumers in executors does not commit offsets unnecessarily - .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") - - // If buffer config is not set, set it to reasonable value to work around - // buffer issues (see KAFKA-3135) - .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .build() - private def strategy(caseInsensitiveParams: Map[String, String]) = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { case ("assign", value) => @@ -414,30 +364,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister logWarning("maxOffsetsPerTrigger option ignored in batch queries") } } - - /** Class to conveniently update Kafka config params, while logging the changes */ - private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { - private val map = new ju.HashMap[String, Object](kafkaParams.asJava) - - def set(key: String, value: Object): this.type = { - map.put(key, value) - logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") - this - } - - def setIfUnset(key: String, value: Object): ConfigUpdater = { - if (!map.containsKey(key)) { - map.put(key, value) - logInfo(s"$module: Set $key to $value") - } - this - } - - def build(): ju.Map[String, Object] = map - } } -private[kafka010] object KafkaSourceProvider { +private[kafka010] object KafkaSourceProvider extends Logging { private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign") private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" @@ -459,4 +388,66 @@ private[kafka010] object KafkaSourceProvider { case None => defaultOffsets } } + + def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] = + ConfigUpdater("source", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial + // offsets by itself instead of counting on KafkaConsumer. 
+ .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + + // So that consumers in the driver does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // So that the driver does not pull too much data + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + def kafkaParamsForExecutors( + specifiedKafkaParams: Map[String, String], + uniqueGroupId: String): ju.Map[String, Object] = + ConfigUpdater("executor", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Make sure executors do only what the driver tells them. + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + + // So that consumers in executors do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") + + // So that consumers in executors does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + /** Class to conveniently update Kafka config params, while logging the changes */ + private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { + private val map = new ju.HashMap[String, Object](kafkaParams.asJava) + + def set(key: String, value: Object): this.type = { + map.put(key, value) + logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") + this + } + + def setIfUnset(key: String, value: Object): ConfigUpdater = { + if (!map.containsKey(key)) { + map.put(key, value) + logDebug(s"$module: Set $key to $value") + } + this + } + + def build(): ju.Map[String, Object] = map + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 6fb3473eb75f..9d9e2aaba807 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -125,16 +125,15 @@ private[kafka010] class KafkaSourceRDD( context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic - if (!reuseKafkaConsumer) { - // if we can't reuse CachedKafkaConsumers, let's reset the groupId to something unique - // to each task (i.e., append the task's unique partition id), because we will have - // multiple tasks (e.g., in the case of union) reading from the same topic partitions - val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val id = TaskContext.getPartitionId() - executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + id) - } val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) + val consumer = + if (!reuseKafkaConsumer) { + // If we can't reuse CachedKafkaConsumers, creating a new 
CachedKafkaConsumer. As here we + // uses `assign`, we don't need to worry about the "group.id" conflicts. + CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) + } else { + CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) + } val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -170,7 +169,7 @@ private[kafka010] class KafkaSourceRDD( override protected def close(): Unit = { if (!reuseKafkaConsumer) { // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - CachedKafkaConsumer.removeKafkaConsumer(topic, kafkaPartition, executorKafkaParams) + consumer.close() } else { // Indicate that we're no longer using this consumer CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 4c6e2ce87e29..62cdf5b1134e 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -199,7 +199,7 @@ private[spark] class KafkaRDD[K, V]( val consumer = if (useConsumerCache) { CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber > 1) { + if (context.attemptNumber >= 1) { // just in case the prior attempt failures were cache related CachedKafkaConsumer.remove(groupId, part.topic, part.partition) } From b90bf520fd7b979a90d1377cfc2ee7f0bf82c705 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Apr 2017 19:38:14 -0700 Subject: [PATCH 371/512] [SPARK-12837][CORE] Do not send the name of internal accumulator to executor side ## What changes were proposed in this pull request? When sending accumulator updates back to driver, the network overhead is pretty big as there are a lot of accumulators, e.g. `TaskMetrics` will send about 20 accumulators everytime, there may be a lot of `SQLMetric` if the query plan is complicated. Therefore, it's critical to reduce the size of serialized accumulator. A simple way is to not send the name of internal accumulators to executor side, as it's unnecessary. When executor sends accumulator updates back to driver, we can look up the accumulator name in `AccumulatorContext` easily. Note that, we still need to send names of normal accumulators, as the user code run at executor side may rely on accumulator names. In the future, we should reimplement `TaskMetrics` to not rely on accumulators and use custom serialization. Tried on the example in https://issues.apache.org/jira/browse/SPARK-12837, the size of serialized accumulator has been cut down by about 40%. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #17596 from cloud-fan/oom. 
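To make the name-lookup idea concrete, here is a minimal, self-contained sketch (not code from this patch; the registry, class and method names are hypothetical) of resolving accumulator names on the driver by id instead of shipping them with every serialized task result:

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical driver-side registry: maps accumulator id -> name.
object AccumNameRegistry {
  private val names = new ConcurrentHashMap[Long, String]()
  def register(id: Long, name: String): Unit = names.put(id, name)
  def nameOf(id: Long): Option[String] = Option(names.get(id))
}

// Hypothetical update shipped back from an executor: only id and value, no name.
final case class AccumUpdate(id: Long, value: Long)

object NameLookupSketch {
  def main(args: Array[String]): Unit = {
    // The driver registers the internal accumulator once, so its name never
    // has to travel with every task result coming back from executors.
    AccumNameRegistry.register(1L, "internal.metrics.resultSize")

    // An executor reports an update identified purely by id.
    val update = AccumUpdate(1L, 2048L)

    // The driver resolves the human-readable name locally.
    val name = AccumNameRegistry.nameOf(update.id).getOrElse("<unregistered>")
    println(s"$name = ${update.value}")
  }
}
```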
--- .../apache/spark/executor/TaskMetrics.scala | 29 ++++++------- .../org/apache/spark/scheduler/Task.scala | 13 +++--- .../org/apache/spark/util/AccumulatorV2.scala | 28 +++++++------ .../spark/scheduler/TaskContextSuite.scala | 2 +- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- .../SpecificParquetRecordReaderBase.java | 12 +++--- .../parquet/ParquetFilterSuite.scala | 42 +++++++++++++++---- 8 files changed, 76 insertions(+), 54 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dfd2f818acda..a3ce3d1ccc5e 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - accumulators.find { acc => - acc.name.isDefined && acc.name.get == name - } + private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = { + // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its + // value will be updated at driver side. + internalAccums.filter(a => !a.isZero || a == _resultSize) } } @@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging { */ def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = { val tm = new TaskMetrics - val (internalAccums, externalAccums) = - accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get)) - - internalAccums.foreach { acc => - val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]] - tmAcc.metadata = acc.metadata - tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + for (acc <- accums) { + val name = acc.name + if (name.isDefined && tm.nameToAccums.contains(name.get)) { + val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]] + tmAcc.metadata = acc.metadata + tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + } else { + tm.externalAccums += acc + } } - - tm.externalAccums ++= externalAccums tm } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7fd2918960cd..5c337b992c84 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -182,14 +182,11 @@ private[spark] abstract class Task[T]( */ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = { if (context != null) { - context.taskMetrics.internalAccums.filter { a => - // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its - // value will be updated at driver side. - // Note: internal accumulators representing task metrics always count failed values - !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE) - // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter - // them out. 
- } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) + // Note: internal accumulators representing task metrics always count failed values + context.taskMetrics.nonZeroInternalAccums() ++ + // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not + // filter them out. + context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) } else { Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 7479de55140e..a65ec75cc5db 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { * Returns the name of this accumulator, can only be called after registration. */ final def name: Option[String] = { - assertMetadataNotNull() - metadata.name + if (atDriverSide) { + AccumulatorContext.get(id).flatMap(_.metadata.name) + } else { + assertMetadataNotNull() + metadata.name + } } /** @@ -161,7 +165,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } val copyAcc = copyAndReset() assert(copyAcc.isZero, "copyAndReset must return a zero value copy") - copyAcc.metadata = metadata + val isInternalAcc = + (name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) || + getClass.getSimpleName == "SQLMetric" + if (isInternalAcc) { + // Do not serialize the name of internal accumulator and send it to executor. + copyAcc.metadata = metadata.copy(name = None) + } else { + copyAcc.metadata = metadata + } copyAcc } else { this @@ -263,16 +275,6 @@ private[spark] object AccumulatorContext { originals.clear() } - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - originals.values().asScala.find { ref => - val acc = ref.get - acc != null && acc.name.isDefined && acc.name.get == name - }.map(_.get) - } - // Identifier for distinguishing SQL metrics from other accumulators private[spark] val SQL_ACCUM_IDENTIFIER = "sql" } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 8f576daa77d1..b22da565d86e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark sc = new SparkContext("local", "test") // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. 
- val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 93964a2d5674..48be3be81755 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -293,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics val inputMetrics = taskMetrics.inputMetrics diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a64dbeae4729..a77c8e3cab4e 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -830,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions { hasHadoopInput: Boolean, hasOutput: Boolean, hasRecords: Boolean = true) = { - val t = TaskMetrics.empty + val t = TaskMetrics.registered // Set CPU times same as wall times for testing purpose t.setExecutorDeserializeTime(a) t.setExecutorDeserializeCpuTime(a) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index eb97118872ea..0bab321a657d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -153,14 +153,14 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont } // For test purpose. - // If the predefined accumulator exists, the row group number to read will be updated - // to the accumulator. So we can check if the row groups are filtered or not in test case. + // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read + // will be updated to the accumulator. So we can check if the row groups are filtered or not + // in test case. 
TaskContext taskContext = TaskContext$.MODULE$.get(); if (taskContext != null) { - Option> accu = taskContext.taskMetrics() - .lookForAccumulatorByName("numRowGroups"); - if (accu.isDefined()) { - ((LongAccumulator)accu.get()).add((long)blocks.size()); + Option> accu = taskContext.taskMetrics().externalAccums().lastOption(); + if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { + ((AccumulatorV2)accu.get()).add(blocks.size()); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 9a3328fcecee..dd53b561326f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.{AccumulatorContext, LongAccumulator} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -499,18 +499,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val path = s"${dir.getCanonicalPath}/table" (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path) - Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) { - val accu = new LongAccumulator - accu.register(sparkContext, Some("numRowGroups")) + Seq(true, false).foreach { enablePushDown => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> enablePushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition(_.foreach(v => accu.add(0))) df.collect - val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups") - assert(numRowGroups.isDefined) - assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value)) + if (enablePushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } AccumulatorContext.remove(accu.id) } } @@ -537,3 +539,27 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + +class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { + private var _sum = 0 + + override def isZero: Boolean = _sum == 0 + + override def copy(): AccumulatorV2[Integer, Integer] = { + val acc = new NumRowGroupsAcc() + acc._sum = _sum + acc + } + + override def reset(): Unit = _sum = 0 + + override def add(v: Integer): Unit = _sum += v + + override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match { + case a: NumRowGroupsAcc => _sum += a._sum + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } + + override def value: Integer = _sum +} From 7fe8249793bd3eed4fa67cb4a210264a80520786 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Thu, 27 Apr 2017 22:29:47 -0700 Subject: [PATCH 372/512] [SPARKR][DOC] Document LinearSVC in R programming guide ## What changes were proposed in this pull request? add link to svmLinear in the SparkR programming document. ## How was this patch tested? 
Build doc manually and click the link to the document. It looks good. Author: wangmiao1981 Closes #17797 from wangmiao1981/doc. --- docs/sparkr.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sparkr.md b/docs/sparkr.md index c3336ac2ce86..c85cfd45c456 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -482,6 +482,7 @@ SparkR supports the following machine learning algorithms currently: * [`spark.logit`](api/R/spark.logit.html): [`Logistic Regression`](ml-classification-regression.html#logistic-regression) * [`spark.mlp`](api/R/spark.mlp.html): [`Multilayer Perceptron (MLP)`](ml-classification-regression.html#multilayer-perceptron-classifier) * [`spark.naiveBayes`](api/R/spark.naiveBayes.html): [`Naive Bayes`](ml-classification-regression.html#naive-bayes) +* [`spark.svmLinear`](api/R/spark.svmLinear.html): [`Linear Support Vector Machine`](ml-classification-regression.html#linear-support-vector-machine) #### Regression From e3c816043389e227db5e7a328c7c554209b4f394 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 28 Apr 2017 14:16:40 +0800 Subject: [PATCH 373/512] [SPARK-20476][SQL] Block users to create a table that use commas in the column names ### What changes were proposed in this pull request? ```SQL hive> create table t1(`a,` string); OK Time taken: 1.399 seconds hive> create table t2(`a,` string, b string); FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.DDLTask. java.lang.RuntimeException: MetaException(message:org.apache.hadoop.hive.serde2.SerDeException org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe: columns has 3 elements while columns.types has 2 elements!) hive> create table t2(`a,` string, b string) stored as parquet; FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.DDLTask. java.lang.IllegalArgumentException: ParquetHiveSerde initialization failed. Number of column name and column type differs. columnNames = [a, , b], columnTypes = [string, string] ``` It has a bug in Hive metastore. When users do not provide alias name in the SELECT query, we call `toPrettySQL` to generate the alias name. For example, the string `get_json_object(jstring, '$.f1')` will be the alias name for the function call in the statement ```SQL SELECT key, get_json_object(jstring, '$.f1') FROM tempView ``` Above is not an issue for the SELECT query statements. However, for CTAS, we hit the issue due to a bug in Hive metastore. Hive metastore does not like the column names containing commas and returned a confusing error message, like: ``` 17/04/26 23:12:56 ERROR [hive.log(397) -- main]: error in initSerDe: org.apache.hadoop.hive.serde2.SerDeException org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe: columns has 2 elements while columns.types has 1 elements! org.apache.hadoop.hive.serde2.SerDeException: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe: columns has 2 elements while columns.types has 1 elements! ``` Thus, this PR is to block users to create a table in Hive metastore when the table table has a column containing commas in the name. ### How was this patch tested? Added a test case Author: Xiao Li Closes #17781 from gatorsmile/blockIllegalColumnNames. 
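For users who hit the new error, one workaround (a sketch only; the table name `t_with_alias` and the column alias `f1` are placeholders) is to alias the generated column explicitly, so the auto-generated name containing a comma never reaches the Hive metastore:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.get_json_object

object CommaColumnWorkaround {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("comma-column-workaround")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._

    val df = Seq(("1", """{"f1": "value1", "f5": 5.23}""")).toDF("key", "jstring")

    // Without an explicit alias the generated column would be named
    // "get_json_object(jstring, $.f1)", which contains a comma and is now rejected.
    df.select($"key", get_json_object($"jstring", "$.f1").alias("f1"))
      .write
      .format("hive")
      .saveAsTable("t_with_alias")

    spark.stop()
  }
}
```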
--- .../spark/sql/hive/HiveExternalCatalog.scala | 18 ++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 24 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 71e33c46b9ae..ba48facff293 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -137,6 +137,22 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + /** + * Checks the validity of column names. Hive metastore disallows the table to use comma in + * data column names. Partition columns do not have such a restriction. Views do not have such + * a restriction. + */ + private def verifyColumnNames(table: CatalogTable): Unit = { + if (table.tableType != VIEW) { + table.dataSchema.map(_.name).foreach { colName => + if (colName.contains(",")) { + throw new AnalysisException("Cannot create a table having a column whose name contains " + + s"commas in Hive metastore. Table: ${table.identifier}; Column: $colName") + } + } + } + } + // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- @@ -202,6 +218,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val table = tableDefinition.identifier.table requireDbExists(db) verifyTableProperties(tableDefinition) + verifyColumnNames(tableDefinition) if (tableExists(db, table) && !ignoreIfExists) { throw new TableAlreadyExistsException(db = db, table = table) @@ -614,6 +631,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, table) val rawTable = getRawTable(db, table) val withNewSchema = rawTable.copy(schema = schema) + verifyColumnNames(withNewSchema) // Add table metadata such as table schema, partition columns, etc. to table properties. val updatedTable = withNewSchema.copy( properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 75f3744ff35b..c944f28d10ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1976,6 +1976,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("Auto alias construction of get_json_object") { + val df = Seq(("1", """{"f1": "value1", "f5": 5.23}""")).toDF("key", "jstring") + val expectedMsg = "Cannot create a table having a column whose name contains commas " + + "in Hive metastore. 
Table: `default`.`t`; Column: get_json_object(jstring, $.f1)" + + withTable("t") { + val e = intercept[AnalysisException] { + df.select($"key", functions.get_json_object($"jstring", "$.f1")) + .write.format("hive").saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + + withTempView("tempView") { + withTable("t") { + df.createTempView("tempView") + val e = intercept[AnalysisException] { + sql("CREATE TABLE t AS SELECT key, get_json_object(jstring, '$.f1') FROM tempView") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } + test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") { withTable("spark_19912") { Seq( From 59e3a564448777657125b6f65057ed20d0162d13 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 28 Apr 2017 14:41:53 +0800 Subject: [PATCH 374/512] [SPARK-14471][SQL] Aliases in SELECT could be used in GROUP BY ## What changes were proposed in this pull request? This pr added a new rule in `Analyzer` to resolve aliases in `GROUP BY`. The current master throws an exception if `GROUP BY` clauses have aliases in `SELECT`; ``` scala> spark.sql("select a a1, a1 + 1 as b, count(1) from t group by a1") org.apache.spark.sql.AnalysisException: cannot resolve '`a1`' given input columns: [a]; line 1 pos 51; 'Aggregate ['a1], [a#83L AS a1#87L, ('a1 + 1) AS b#88, count(1) AS count(1)#90L] +- SubqueryAlias t +- Project [id#80L AS a#83L] +- Range (0, 10, step=1, splits=Some(8)) at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:77) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:74) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) ``` ## How was this patch tested? Added tests in `SQLQuerySuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro Closes #17191 from maropu/SPARK-14471. --- .../sql/catalyst/analysis/Analyzer.scala | 71 ++++++++++++------- .../apache/spark/sql/internal/SQLConf.scala | 8 +++ .../sql-tests/inputs/group-by-ordinal.sql | 3 + .../resources/sql-tests/inputs/group-by.sql | 18 +++++ .../results/group-by-ordinal.sql.out | 22 ++++-- .../sql-tests/results/group-by.sql.out | 66 ++++++++++++++++- 6 files changed, 156 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dcadbbc90f43..72e7d5dd3638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -136,6 +136,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: + ResolveAggAliasInGroupBy :: ResolveMissingReferences :: ExtractGenerator :: ResolveGenerate :: @@ -172,7 +173,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. 
*/ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -200,7 +201,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -242,7 +243,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -614,7 +615,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -786,7 +787,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -844,11 +845,10 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { + q.transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -961,7 +961,7 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. @@ -997,6 +997,27 @@ class Analyzer( } } + /** + * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. + * This rule is expected to run after [[ResolveReferences]] applied. 
+ */ + object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedAttribute]) => + // This is a strict check though, we put this to apply the rule only in alias expressions + def notResolvableByChild(attrName: String): Boolean = + !child.output.exists(a => resolver(a.name, attrName)) + agg.copy(groupingExpressions = groups.map { + case u: UnresolvedAttribute if notResolvableByChild(u.name) => + aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + case e => e + }) + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -1006,7 +1027,7 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa @@ -1130,7 +1151,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1469,7 +1490,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1484,7 +1505,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1510,7 +1531,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1682,7 +1703,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1740,7 +1761,7 @@ class Analyzer( * that wrap the [[Generator]]. 
*/ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2057,7 +2078,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2102,7 +2123,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2167,7 +2188,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2232,7 +2253,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2318,7 +2339,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2352,7 +2373,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2406,7 +2427,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2474,7 +2495,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. 
*/ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e1798e22b9f..b24419a41edb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -421,6 +421,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases") + .doc("When true, aliases in a select list can be used in group by clauses. When false, " + + "an analysis exception is thrown in the case.") + .booleanConf + .createWithDefault(true) + // The output committer class used by data sources. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. val OUTPUT_COMMITTER_CLASS = @@ -1003,6 +1009,8 @@ class SQLConf extends Serializable with Logging { def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 9c8d851e36e9..6566338f3d4a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -49,6 +49,9 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1; -- group by ordinal followed by having select count(a), a from (select 1 as a) tmp group by 2 having a > 0; +-- mixed cases: group-by ordinals and aliases +select a, a AS k, count(b) from data group by k, 1; + -- turn of group by ordinal set spark.sql.groupByOrdinal=false; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 4d0ed4315300..a7994f3beaff 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -35,3 +35,21 @@ FROM testData; -- Aggregate with foldable input and multiple distinct groups. SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; + +-- Aliases in SELECT could be used in GROUP BY +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; + +-- Aggregate functions cannot be used in GROUP BY +SELECT COUNT(b) AS k FROM testData GROUP BY k; + +-- Test data. 
+CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v); +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a; + +-- turn off group by aliases +set spark.sql.groupByAliases=false; + +-- Check analysis exceptions +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index d03681d0ea59..9ecbe19078dd 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 19 +-- Number of queries: 20 -- !query 0 @@ -173,16 +173,26 @@ struct -- !query 17 -set spark.sql.groupByOrdinal=false +select a, a AS k, count(b) from data group by k, 1 -- !query 17 schema -struct +struct -- !query 17 output -spark.sql.groupByOrdinal false +1 1 2 +2 2 2 +3 3 2 -- !query 18 -select sum(b) from data group by -1 +set spark.sql.groupByOrdinal=false -- !query 18 schema -struct +struct -- !query 18 output +spark.sql.groupByOrdinal false + + +-- !query 19 +select sum(b) from data group by -1 +-- !query 19 schema +struct +-- !query 19 output 9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 4b87d5161fc0..6bf9dff883c1 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 22 -- !query 0 @@ -139,3 +139,67 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS struct -- !query 14 output 1 1 + + +-- !query 15 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 15 schema +struct +-- !query 15 output +1 2 +2 2 +3 2 +NULL 1 + + +-- !query 16 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1 +-- !query 16 schema +struct +-- !query 16 output +2 2 +3 2 + + +-- !query 17 +SELECT COUNT(b) AS k FROM testData GROUP BY k +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`); + + +-- !query 18 +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 20 +set spark.sql.groupByAliases=false +-- !query 20 schema +struct +-- !query 20 output +spark.sql.groupByAliases false + + +-- !query 21 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 From 8c911adac56a1b1d95bc19915e0070ce7305257c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 28 Apr 2017 08:49:35 +0100 Subject: [PATCH 375/512] [SPARK-20465][CORE] Throws a proper exception when any temp directory could not be got ## What changes were proposed in this pull request? This PR proposes to throw an exception with better message rather than `ArrayIndexOutOfBoundsException` when temp directories could not be created. Running the commands below: ```bash ./bin/spark-shell --conf spark.local.dir=/NONEXISTENT_DIR_ONE,/NONEXISTENT_DIR_TWO ``` produces ... **Before** ``` Exception in thread "main" java.lang.ExceptionInInitializerError ... Caused by: java.lang.ArrayIndexOutOfBoundsException: 0 ... ``` **After** ``` Exception in thread "main" java.lang.ExceptionInInitializerError ... Caused by: java.io.IOException: Failed to get a temp directory under [/NONEXISTENT_DIR_ONE,/NONEXISTENT_DIR_TWO]. ... ``` ## How was this patch tested? Unit tests in `LocalDirsSuite.scala`. Author: hyukjinkwon Closes #17768 from HyukjinKwon/throws-temp-dir-exception. --- .../scala/org/apache/spark/util/Utils.scala | 6 ++++- .../apache/spark/storage/LocalDirsSuite.scala | 23 ++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e042badcdd4a..4d37db96dfc3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -740,7 +740,11 @@ private[spark] object Utils extends Logging { * always return a single directory. 
*/ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf)(0) + getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val configuredLocalDirs = getConfiguredLocalDirs(conf) + throw new IOException( + s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } } private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index c7074078d8fd..f7b3a2754f0e 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.File +import java.io.{File, IOException} import org.scalatest.BeforeAndAfter @@ -33,9 +33,13 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { Utils.clearLocalRootDirs() } + after { + Utils.clearLocalRootDirs() + } + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) val conf = new SparkConf(false) .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") assert(new File(Utils.getLocalDir(conf)).exists()) @@ -43,7 +47,7 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { test("SPARK_LOCAL_DIRS override also affects driver") { // Regression test for SPARK-2975 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) @@ -51,4 +55,17 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { assert(new File(Utils.getLocalDir(conf)).exists()) } + test("Utils.getLocalDir() throws an exception if any temporary directory cannot be retrieved") { + val path1 = "/NONEXISTENT_PATH_ONE" + val path2 = "/NONEXISTENT_PATH_TWO" + assert(!new File(path1).exists()) + assert(!new File(path2).exists()) + val conf = new SparkConf(false).set("spark.local.dir", s"$path1,$path2") + val message = intercept[IOException] { + Utils.getLocalDir(conf) + }.getMessage + // If any temporary directory could not be retrieved under the given paths above, it should + // throw an exception with the message that includes the paths. + assert(message.contains(s"$path1,$path2")) + } } From 733b81b835f952ab96723c749461d6afc0c71974 Mon Sep 17 00:00:00 2001 From: Bill Chambers Date: Fri, 28 Apr 2017 10:18:31 -0700 Subject: [PATCH 376/512] [SPARK-20496][SS] Bug in KafkaWriter Looks at Unanalyzed Plans ## What changes were proposed in this pull request? We didn't enforce analyzed plans in Spark 2.1 when writing out to Kafka. ## How was this patch tested? New unit test. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bill Chambers Closes #17804 from anabranch/SPARK-20496-2. 
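The user-facing pattern that previously failed is a plain batch write of a derived column to Kafka, roughly as sketched below (the broker address and topic are placeholders):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{struct, to_json}

object KafkaBatchWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("kafka-batch-write").getOrCreate()

    // The select uses `*`, so the plan's output is only known after analysis;
    // inspecting the unanalyzed logical plan is what used to throw UnresolvedException.
    val events = spark.range(1, 1000).select(to_json(struct("*")).alias("value"))

    events.write
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092") // placeholder broker
      .option("topic", "events")                           // placeholder topic
      .save()

    spark.stop()
  }
}
```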
--- .../apache/spark/sql/kafka010/KafkaWriter.scala | 4 ++-- .../spark/sql/kafka010/KafkaSinkSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index a637d52c933a..61936e32fd83 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -47,7 +47,7 @@ private[kafka010] object KafkaWriter extends Logging { queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.logical.output + val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic == None) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +84,7 @@ private[kafka010] object KafkaWriter extends Logging { queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.logical.output + val schema = queryExecution.analyzed.output validateQuery(queryExecution, kafkaParameters, topic) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { queryExecution.toRdd.foreachPartition { iter => diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 4bd052d249ec..2ab336c7ac47 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, DataType} @@ -108,6 +109,21 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { s"save mode overwrite not allowed for kafka")) } + test("SPARK-20496: batch - enforce analyzed plans") { + val inputEvents = + spark.range(1, 1000) + .select(to_json(struct("*")) as 'value) + + val topic = newTopic() + testUtils.createTopic(topic) + // used to throw UnresolvedException + inputEvents.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .save() + } + test("streaming - write to kafka with topic field") { val input = MemoryStream[String] val topic = newTopic() From 5d71f3db83138bf50749dcd425ef7365c34bd799 Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Fri, 28 Apr 2017 14:06:57 -0700 Subject: [PATCH 377/512] [SPARK-20514][CORE] Upgrade Jetty to 9.3.11.v20160721 Upgrade Jetty so it can work with Hadoop 3 (alpha 2 release, in particular). Without this change, because of incompatibily between Jetty versions, Spark fails to compile when built against Hadoop 3 ## How was this patch tested? Unit tests being run. Author: Mark Grover Closes #17790 from markgrover/spark-20514. 
--- core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index bdbdba578085..edf328b5ae53 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -29,8 +29,8 @@ import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ +import org.eclipse.jetty.server.handler.gzip.GzipHandler import org.eclipse.jetty.servlet._ -import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue diff --git a/pom.xml b/pom.xml index b6654c1411d2..517ebc5c83fc 100644 --- a/pom.xml +++ b/pom.xml @@ -136,7 +136,7 @@ 10.12.1.1 1.8.2 1.6.0 - 9.2.16.v20160414 + 9.3.11.v20160721 3.1.0 0.8.0 2.4.0 From ebff519c5ead31536e17a5b16cc47c2bf380d55e Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Fri, 28 Apr 2017 14:47:17 -0700 Subject: [PATCH 378/512] [SPARK-20471] Remove AggregateBenchmark testsuite warning: Two level hashmap is disabled but vectorized hashmap is enabled What changes were proposed in this pull request? remove AggregateBenchmark testsuite warning: such as '14:26:33.220 WARN org.apache.spark.sql.execution.aggregate.HashAggregateExec: Two level hashmap is disabled but vectorized hashmap is enabled.' How was this patch tested? unit tests: AggregateBenchmark Modify the 'ignore function for 'test funtion Author: caoxuewen Closes #17771 from heary-cao/AggregateBenchmark. 
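For context, a small sketch (not from the patch) of the session-level flags involved: the warning fires when whole-stage codegen runs with the two-level hashmap disabled while the vectorized hashmap stays enabled; the benchmark cases below now pin both flags off together.

```scala
// Assumes a spark-shell style SparkSession named `spark`.
// Combination that produces the HashAggregateExec warning quoted above:
spark.conf.set("spark.sql.codegen.wholeStage", "true")
spark.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
spark.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")

// The patched benchmark cases also disable the vectorized map, so no warning is logged:
spark.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false")
```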
--- .../spark/sql/execution/benchmark/AggregateBenchmark.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 8a2993bdf4b2..8a798fb44469 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -107,6 +107,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -148,6 +149,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -187,6 +189,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -225,6 +228,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -273,6 +277,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } From 77bcd77ed5fbd91fe61849cca76a8dffe5e4d6b2 Mon Sep 17 00:00:00 2001 From: Aaditya Ramesh Date: Fri, 28 Apr 2017 15:28:56 -0700 Subject: [PATCH 379/512] [SPARK-19525][CORE] Add RDD checkpoint compression support ## What changes were proposed in this pull request? This PR adds RDD checkpoint compression support and add a new config `spark.checkpoint.compress` to enable/disable it. Credit goes to aramesh117 Closes #17024 ## How was this patch tested? The new unit test. Author: Shixiong Zhu Author: Aaditya Ramesh Closes #17789 from zsxwing/pr17024. 
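A minimal usage sketch, assuming the defaults described above: the flag is off unless set, and compression reuses whatever `spark.io.compression.codec` resolves to. The checkpoint directory below is an arbitrary example path.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("checkpoint-compress-sketch")
  .set("spark.checkpoint.compress", "true")   // new flag, defaults to false
val sc = new SparkContext(conf)
sc.setCheckpointDir("/tmp/checkpoints")        // example path

val rdd = sc.parallelize(1 to 1000000).map(_ * 2)
rdd.checkpoint()                               // checkpoint files are now written through the compression codec
rdd.count()                                    // materializes the RDD and triggers the checkpoint
```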
--- .../spark/internal/config/package.scala | 6 +++ .../spark/rdd/ReliableCheckpointRDD.scala | 24 ++++++++++- .../org/apache/spark/CheckpointSuite.scala | 41 +++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 2f0a3064be11..7f7921d56f49 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -272,4 +272,10 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val CHECKPOINT_COMPRESS = + ConfigBuilder("spark.checkpoint.compress") + .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " + + "spark.io.compression.codec.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index e0a29b48314f..37c67cee55f9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.io.{FileNotFoundException, IOException} +import java.util.concurrent.TimeUnit import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CHECKPOINT_COMPRESS +import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { originalRDD: RDD[T], checkpointDir: String, blockSize: Int = -1): ReliableCheckpointRDD[T] = { + val checkpointStartTimeNs = System.nanoTime() val sc = originalRDD.sparkContext @@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging { writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) } + val checkpointDurationMs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs) + logInfo(s"Checkpointing took $checkpointDurationMs ms.") + val newRDD = new ReliableCheckpointRDD[T]( sc, checkpointDirPath.toString, originalRDD.partitioner) if (newRDD.partitions.length != originalRDD.partitions.length) { @@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging { val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) + val fileStream = fs.create(tempOutputPath, false, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream) + } else { + fileStream + } } else { // This is mainly for testing purpose fs.create(tempOutputPath, false, bufferSize, @@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging { val env = SparkEnv.get val fs = path.getFileSystem(broadcastedConf.value.value) val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) + val fileInputStream = { + val fileStream = fs.open(path, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream) + } 
else { + fileStream + } + } val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index b117c7709b46..ee70a3399efe 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,8 +21,10 @@ import java.io.File import scala.reflect.ClassTag +import com.google.common.io.ByteStreams import org.apache.hadoop.fs.Path +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -580,3 +582,42 @@ object CheckpointSuite { ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } } + +class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext { + + test("checkpoint compression") { + val checkpointDir = Utils.createTempDir() + try { + val conf = new SparkConf() + .set("spark.checkpoint.compress", "true") + .set("spark.ui.enabled", "false") + sc = new SparkContext("local", "test", conf) + sc.setCheckpointDir(checkpointDir.toString) + val rdd = sc.makeRDD(1 to 20, numSlices = 1) + rdd.checkpoint() + assert(rdd.collect().toSeq === (1 to 20)) + + // Verify that RDD is checkpointed + assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]]) + + val checkpointPath = new Path(rdd.getCheckpointFile.get) + val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration) + val checkpointFile = + fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get + + // Verify the checkpoint file is compressed, in other words, can be decompressed + val compressedInputStream = CompressionCodec.createCodec(conf) + .compressedInputStream(fs.open(checkpointFile)) + try { + ByteStreams.toByteArray(compressedInputStream) + } finally { + compressedInputStream.close() + } + + // Verify that the compressed content can be read back + assert(rdd.collect().toSeq === (1 to 20)) + } finally { + Utils.deleteRecursively(checkpointDir) + } + } +} From 814a61a867ded965433c944c90961df529ac83ab Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Fri, 28 Apr 2017 23:12:26 -0700 Subject: [PATCH 380/512] [SPARK-20487][SQL] Display `serde` for `HiveTableScan` node in explained plan ## What changes were proposed in this pull request? This was a suggestion by rxin at https://github.com/apache/spark/pull/17780#issuecomment-298073408 ## How was this patch tested? 
- modified existing unit test - manual testing: ``` scala> hc.sql(" SELECT * FROM tejasp_bucketed_partitioned_1 where name = '' ").explain(true) == Parsed Logical Plan == 'Project [*] +- 'Filter ('name = ) +- 'UnresolvedRelation `tejasp_bucketed_partitioned_1` == Analyzed Logical Plan == user_id: bigint, name: string, ds: string Project [user_id#24L, name#25, ds#26] +- Filter (name#25 = ) +- SubqueryAlias tejasp_bucketed_partitioned_1 +- CatalogRelation `default`.`tejasp_bucketed_partitioned_1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [user_id#24L, name#25], [ds#26] == Optimized Logical Plan == Filter (isnotnull(name#25) && (name#25 = )) +- CatalogRelation `default`.`tejasp_bucketed_partitioned_1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [user_id#24L, name#25], [ds#26] == Physical Plan == *Filter (isnotnull(name#25) && (name#25 = )) +- HiveTableScan [user_id#24L, name#25, ds#26], CatalogRelation `default`.`tejasp_bucketed_partitioned_1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [user_id#24L, name#25], [ds#26] ``` Author: Tejas Patil Closes #17806 from tejasapatil/add_serde. --- .../org/apache/spark/sql/catalyst/trees/TreeNode.scala | 6 +++++- .../apache/spark/sql/hive/execution/HiveExplainSuite.scala | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index b091315f24f1..2109c1c23b70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -444,7 +444,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case None => Nil case Some(null) => Nil case Some(any) => any :: Nil - case table: CatalogTable => table.identifier :: Nil + case table: CatalogTable => + table.storage.serde match { + case Some(serde) => table.identifier :: serde :: Nil + case _ => table.identifier :: Nil + } case other => other :: Nil }.mkString(", ") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index ebafe6de0c83..aa1ca2909074 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -43,7 +43,9 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("explain extended command") { checkKeywordsExist(sql(" explain select * from src where key=123 "), - "== Physical Plan ==") + "== Physical Plan ==", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + checkKeywordsNotExist(sql(" explain select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", From b28c3bc2020a6936e4ac4c28d49fd832a952af42 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Sat, 29 Apr 2017 10:31:01 -0700 Subject: [PATCH 381/512] [SPARK-20477][SPARKR][DOC] Document R bisecting k-means in R programming guide ## What changes were proposed in this pull request? Add hyper link in the SparkR programming guide. ## How was this patch tested? Build doc and manually check the doc link. Author: wangmiao1981 Closes #17805 from wangmiao1981/doc. 
--- docs/sparkr.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sparkr.md b/docs/sparkr.md index c85cfd45c456..16b1ef651242 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -497,6 +497,7 @@ SparkR supports the following machine learning algorithms currently: #### Clustering +* [`spark.bisectingKmeans`](api/R/spark.bisectingKmeans.html): [`Bisecting k-means`](ml-clustering.html#bisecting-k-means) * [`spark.gaussianMixture`](api/R/spark.gaussianMixture.html): [`Gaussian Mixture Model (GMM)`](ml-clustering.html#gaussian-mixture-model-gmm) * [`spark.kmeans`](api/R/spark.kmeans.html): [`K-Means`](ml-clustering.html#k-means) * [`spark.lda`](api/R/spark.lda.html): [`Latent Dirichlet Allocation (LDA)`](ml-clustering.html#latent-dirichlet-allocation-lda) From add9d1bba5cf33218a115428a03d3c76a514aa86 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 29 Apr 2017 10:51:45 -0700 Subject: [PATCH 382/512] [SPARK-19791][ML] Add doc and example for fpgrowth ## What changes were proposed in this pull request? Add a new section for fpm Add Example for FPGrowth in scala and Java updated: Rewrite transform to be more compact. ## How was this patch tested? local doc generation. Author: Yuhao Yang Closes #17130 from hhbyyh/fpmdoc. --- docs/_data/menu-ml.yaml | 2 + docs/ml-frequent-pattern-mining.md | 87 +++++++++++++++++++ docs/mllib-frequent-pattern-mining.md | 2 +- .../examples/ml/JavaFPGrowthExample.java | 77 ++++++++++++++++ .../src/main/python/ml/fpgrowth_example.py | 56 ++++++++++++ .../spark/examples/ml/FPGrowthExample.scala | 67 ++++++++++++++ .../org/apache/spark/ml/fpm/FPGrowth.scala | 35 ++++---- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 2 + 8 files changed, 310 insertions(+), 18 deletions(-) create mode 100644 docs/ml-frequent-pattern-mining.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java create mode 100644 examples/src/main/python/ml/fpgrowth_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index 0c6b9b20a6e4..047423f75aec 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -8,6 +8,8 @@ url: ml-clustering.html - text: Collaborative filtering url: ml-collaborative-filtering.html +- text: Frequent Pattern Mining + url: ml-frequent-pattern-mining.html - text: Model selection and tuning url: ml-tuning.html - text: Advanced topics diff --git a/docs/ml-frequent-pattern-mining.md b/docs/ml-frequent-pattern-mining.md new file mode 100644 index 000000000000..81634de8aade --- /dev/null +++ b/docs/ml-frequent-pattern-mining.md @@ -0,0 +1,87 @@ +--- +layout: global +title: Frequent Pattern Mining +displayTitle: Frequent Pattern Mining +--- + +Mining frequent items, itemsets, subsequences, or other substructures is usually among the +first steps to analyze a large-scale dataset, which has been an active research topic in +data mining for years. +We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) +for more information. + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +## FP-Growth + +The FP-growth algorithm is described in the paper +[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +where "FP" stands for frequent pattern. 
+Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. +Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, +the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets +explicitly, which are usually expensive to generate. +After the second step, the frequent itemsets can be extracted from the FP-tree. +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +PFP distributes the work of growing FP-trees based on the suffixes of transactions, +and hence is more scalable than a single-machine implementation. +We refer users to the papers for more details. + +`spark.ml`'s FP-growth implementation takes the following (hyper-)parameters: + +* `minSupport`: the minimum support for an itemset to be identified as frequent. + For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. +* `minConfidence`: minimum confidence for generating Association Rule. Confidence is an indication of how often an + association rule has been found to be true. For example, if in the transactions itemset `X` appears 4 times, `X` + and `Y` co-occur only 2 times, the confidence for the rule `X => Y` is then 2/4 = 0.5. The parameter will not + affect the mining for frequent itemsets, but specify the minimum confidence for generating association rules + from frequent itemsets. +* `numPartitions`: the number of partitions used to distribute the work. By default the param is not set, and + number of partitions of the input dataset is used. + +The `FPGrowthModel` provides: + +* `freqItemsets`: frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long]) +* `associationRules`: association rules generated with confidence above `minConfidence`, in the format of + DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]). +* `transform`: For each transaction in `itemsCol`, the `transform` method will compare its items against the antecedents + of each association rule. If the record contains all the antecedents of a specific association rule, the rule + will be considered as applicable and its consequents will be added to the prediction result. The transform + method will summarize the consequents from all the applicable rules as prediction. The prediction column has + the same data type as `itemsCol` and does not contain existing items in the `itemsCol`. + + +**Examples** + +
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.fpm.FPGrowth) for more details. + +{% include_example scala/org/apache/spark/examples/ml/FPGrowthExample.scala %} +
+</div>
+
+<div data-lang="java" markdown="1">
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/fpm/FPGrowth.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaFPGrowthExample.java %} +
+</div>
+
+<div data-lang="python" markdown="1">
    +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.fpm.FPGrowth) for more details. + +{% include_example python/ml/fpgrowth_example.py %} +
+</div>
+
+<div data-lang="r" markdown="1">
    + +Refer to the [R API docs](api/R/spark.fpGrowth.html) for more details. + +{% include_example r/ml/fpm.R %} +
+</div>
+
+</div>
    diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 93e3f0b2d226..c9cd7cc85e75 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -24,7 +24,7 @@ explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). -PFP distributes the work of growing FP-trees based on the suffices of transactions, +PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java new file mode 100644 index 000000000000..717ec21c8b20 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.fpm.FPGrowth; +import org.apache.spark.ml.fpm.FPGrowthModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.*; +// $example off$ + +/** + * An example demonstrating FPGrowth. + * Run with + *
+ * <pre>
+ * bin/run-example ml.JavaFPGrowthExample
+ * </pre>
    + */ +public class JavaFPGrowthExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaFPGrowthExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Arrays.asList("1 2 5".split(" "))), + RowFactory.create(Arrays.asList("1 2 3 5".split(" "))), + RowFactory.create(Arrays.asList("1 2".split(" "))) + ); + StructType schema = new StructType(new StructField[]{ new StructField( + "items", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + Dataset itemsDF = spark.createDataFrame(data, schema); + + FPGrowthModel model = new FPGrowth() + .setItemsCol("items") + .setMinSupport(0.5) + .setMinConfidence(0.6) + .fit(itemsDF); + + // Display frequent itemsets. + model.freqItemsets().show(); + + // Display generated association rules. + model.associationRules().show(); + + // transform examines the input items against all the association rules and summarize the + // consequents as prediction + model.transform(itemsDF).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py new file mode 100644 index 000000000000..c92c3c27abb2 --- /dev/null +++ b/examples/src/main/python/ml/fpgrowth_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.ml.fpm import FPGrowth +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating FPGrowth. +Run with: + bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("FPGrowthExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (0, [1, 2, 5]), + (1, [1, 2, 3, 5]), + (2, [1, 2]) + ], ["id", "items"]) + + fpGrowth = FPGrowth(itemsCol="items", minSupport=0.5, minConfidence=0.6) + model = fpGrowth.fit(df) + + # Display frequent itemsets. + model.freqItemsets.show() + + # Display generated association rules. + model.associationRules.show() + + # transform examines the input items against all the association rules and summarize the + # consequents as prediction + model.transform(df).show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala new file mode 100644 index 000000000000..59110d70de55 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println + +// $example on$ +import org.apache.spark.ml.fpm.FPGrowth +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating FP-Growth. + * Run with + * {{{ + * bin/run-example ml.FPGrowthExample + * }}} + */ +object FPGrowthExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + import spark.implicits._ + + // $example on$ + val dataset = spark.createDataset(Seq( + "1 2 5", + "1 2 3 5", + "1 2") + ).map(t => t.split(" ")).toDF("items") + + val fpgrowth = new FPGrowth().setItemsCol("items").setMinSupport(0.5).setMinConfidence(0.6) + val model = fpgrowth.fit(dataset) + + // Display frequent itemsets. + model.freqItemsets.show() + + // Display generated association rules. + model.associationRules.show() + + // transform examines the input items against all the association rules and summarize the + // consequents as prediction + model.transform(dataset).show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index d604c1ac001a..8f00daa59f1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.fpm -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.hadoop.fs.Path @@ -54,7 +53,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { /** * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears - * more than (minSupport * size-of-the-dataset) times will be output + * more than (minSupport * size-of-the-dataset) times will be output in the frequent itemsets. * Default: 0.3 * @group param */ @@ -82,8 +81,8 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { def getNumPartitions: Int = $(numPartitions) /** - * Minimal confidence for generating Association Rule. - * Note that minConfidence has no effect during fitting. + * Minimal confidence for generating Association Rule. minConfidence will not affect the mining + * for frequent itemsets, but will affect the association rules generation. * Default: 0.8 * @group param */ @@ -118,7 +117,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in * Han et al., Mining frequent patterns without - * candidate generation. Note null values in the feature column are ignored during fit(). 
+ * candidate generation. Note null values in the itemsCol column are ignored during fit(). * * @see * Association rule learning (Wikipedia) @@ -167,7 +166,6 @@ class FPGrowth @Since("2.2.0") ( } val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) - val schema = StructType(Seq( StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) @@ -196,7 +194,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] { * :: Experimental :: * Model fitted by FPGrowth. * - * @param freqItemsets frequent items in the format of DataFrame("items"[Seq], "freq"[Long]) + * @param freqItemsets frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long]) */ @Since("2.2.0") @Experimental @@ -244,10 +242,13 @@ class FPGrowthModel private[ml] ( /** * The transform method first generates the association rules according to the frequent itemsets. - * Then for each association rule, it will examine the input items against antecedents and - * summarize the consequents as prediction. The prediction column has the same data type as the - * input column(Array[T]) and will not contain existing items in the input column. The null - * values in the feature columns are treated as empty sets. + * Then for each transaction in itemsCol, the transform method will compare its items against the + * antecedents of each association rule. If the record contains all the antecedents of a + * specific association rule, the rule will be considered as applicable and its consequents + * will be added to the prediction result. The transform method will summarize the consequents + * from all the applicable rules as prediction. The prediction column has the same data type as + * the input column(Array[T]) and will not contain existing items in the input column. The null + * values in the itemsCol columns are treated as empty sets. * WARNING: internally it collects association rules to the driver and uses broadcast for * efficiency. This may bring pressure to driver memory for large set of association rules. */ @@ -335,13 +336,13 @@ private[fpm] object AssociationRules { /** * Computes the association rules with confidence above minConfidence. - * @param dataset DataFrame("items", "freq") containing frequent itemset obtained from - * algorithms like [[FPGrowth]]. + * @param dataset DataFrame("items"[Array], "freq"[Long]) containing frequent itemsets obtained + * from algorithms like [[FPGrowth]]. * @param itemsCol column name for frequent itemsets - * @param freqCol column name for frequent itemsets count - * @param minConfidence minimum confidence for the result association rules - * @return a DataFrame("antecedent", "consequent", "confidence") containing the association - * rules. + * @param freqCol column name for appearance count of the frequent itemsets + * @param minConfidence minimum confidence for generating the association rules + * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]) + * containing the association rules. 
*/ def getAssociationRulesFromFP[T: ClassTag]( dataset: Dataset[_], diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 6806cb03bc42..87f8b9034dde 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -122,6 +122,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) + // numPartitions should not have default value. + assert(fpGrowth.isDefined(fpGrowth.numPartitions) === false) MLTestingUtils.checkCopyAndUids(fpGrowth, model) ParamsSuite.checkParams(fpGrowth) ParamsSuite.checkParams(model) From ee694cdff6fdb47f23370038f87f8594a80a8f27 Mon Sep 17 00:00:00 2001 From: wangmiao1981 Date: Sat, 29 Apr 2017 10:58:48 -0700 Subject: [PATCH 383/512] [SPARK-20533][SPARKR] SparkR Wrappers Model should be private and value should be lazy ## What changes were proposed in this pull request? MultilayerPerceptronClassifierWrapper model should be private. LogisticRegressionWrapper.scala rFeatures and rCoefficients should be lazy. ## How was this patch tested? Unit tests. Author: wangmiao1981 Closes #17808 from wangmiao1981/lazy. --- .../org/apache/spark/ml/r/LogisticRegressionWrapper.scala | 4 ++-- .../spark/ml/r/MultilayerPerceptronClassifierWrapper.scala | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index c96f99cb8343..703bcdf4ca72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -40,13 +40,13 @@ private[r] class LogisticRegressionWrapper private ( private val lrModel: LogisticRegressionModel = pipeline.stages(1).asInstanceOf[LogisticRegressionModel] - val rFeatures: Array[String] = if (lrModel.getFitIntercept) { + lazy val rFeatures: Array[String] = if (lrModel.getFitIntercept) { Array("(Intercept)") ++ features } else { features } - val rCoefficients: Array[Double] = { + lazy val rCoefficients: Array[Double] = { val numRows = lrModel.coefficientMatrix.numRows val numCols = lrModel.coefficientMatrix.numCols val numColsWithIntercept = if (lrModel.getFitIntercept) numCols + 1 else numCols diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index d34de3093114..48c87743dee6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -36,11 +36,11 @@ private[r] class MultilayerPerceptronClassifierWrapper private ( import MultilayerPerceptronClassifierWrapper._ - val mlpModel: MultilayerPerceptronClassificationModel = + private val mlpModel: MultilayerPerceptronClassificationModel = pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel] - val weights: Array[Double] = mlpModel.weights.toArray - val layers: Array[Int] = mlpModel.layers + lazy val weights: Array[Double] = mlpModel.weights.toArray + lazy val layers: Array[Int] = mlpModel.layers def transform(dataset: Dataset[_]): DataFrame = { 
pipeline.transform(dataset) From 70f1bcd7bcd42b30eabcf06a9639363f1ca4b449 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 29 Apr 2017 11:02:17 -0700 Subject: [PATCH 384/512] [SPARK-20493][R] De-duplicate parse logics for DDL-like type strings in R MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? It seems we are using `SQLUtils.getSQLDataType` for type string in structField. It looks we can replace this with `CatalystSqlParser.parseDataType`. They look similar DDL-like type definitions as below: ```scala scala> Seq(Tuple1(Tuple1("a"))).toDF.show() ``` ``` +---+ | _1| +---+ |[a]| +---+ ``` ```scala scala> Seq(Tuple1(Tuple1("a"))).toDF.select($"_1".cast("struct<_1:string>")).show() ``` ``` +---+ | _1| +---+ |[a]| +---+ ``` Such type strings looks identical when R’s one as below: ```R > write.df(sql("SELECT named_struct('_1', 'a') as struct"), "/tmp/aa", "parquet") > collect(read.df("/tmp/aa", "parquet", structType(structField("struct", "struct<_1:string>")))) struct 1 a ``` R’s one is stricter because we are checking the types via regular expressions in R side ahead. Actual logics there look a bit different but as we check it ahead in R side, it looks replacing it would not introduce (I think) no behaviour changes. To make this sure, the tests dedicated for it were added in SPARK-20105. (It looks `structField` is the only place that calls this method). ## How was this patch tested? Existing tests - https://github.com/apache/spark/blob/master/R/pkg/inst/tests/testthat/test_sparkSQL.R#L143-L194 should cover this. Author: hyukjinkwon Closes #17785 from HyukjinKwon/SPARK-20493. --- R/pkg/R/utils.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 13 +++++- R/pkg/inst/tests/testthat/test_utils.R | 6 +-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 43 +------------------ 4 files changed, 24 insertions(+), 46 deletions(-) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fbc89e98847b..d29af00affb9 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -864,6 +864,14 @@ captureJVMException <- function(e, method) { # Extract the first message of JVM exception. first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] stop(paste0(rmsg, "no such table - ", first), call. = FALSE) + } else if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.parser.ParseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "parse error - ", first), call. = FALSE) } else { stop(stacktrace, call. 
= FALSE) } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 2cef7191d4f2..1a3d6df437d7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -150,7 +150,12 @@ test_that("structField type strings", { binary = "BinaryType", boolean = "BooleanType", timestamp = "TimestampType", - date = "DateType") + date = "DateType", + tinyint = "ByteType", + smallint = "ShortType", + int = "IntegerType", + bigint = "LongType", + decimal = "DecimalType(10,0)") complexTypes <- list("map" = "MapType(StringType,IntegerType,true)", "array" = "ArrayType(StringType,true)", @@ -174,7 +179,11 @@ test_that("structField type strings", { numeric = "numeric", character = "character", raw = "raw", - logical = "logical") + logical = "logical", + short = "short", + varchar = "varchar", + long = "long", + char = "char") complexErrors <- list("map" = " integer", "array" = "String", diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 6d006eccf665..1ca383da26ec 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -167,13 +167,13 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { - method <- "getSQLDataType" + method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, - "unknown"), + "col", "unknown", TRUE), error = function(e) { captureJVMException(e, method) }), - "Error in getSQLDataType : illegal argument - Invalid type unknown") + "parse error - .*DataType unknown.*not supported.") }) test_that("hashCode", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index a26d00411fba..d94e528a3ad4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ @@ -92,48 +93,8 @@ private[sql] object SQLUtils extends Logging { def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*) } - def getSQLDataType(dataType: String): DataType = { - dataType match { - case "byte" => org.apache.spark.sql.types.ByteType - case "integer" => org.apache.spark.sql.types.IntegerType - case "float" => org.apache.spark.sql.types.FloatType - case "double" => org.apache.spark.sql.types.DoubleType - case "numeric" => org.apache.spark.sql.types.DoubleType - case "character" => org.apache.spark.sql.types.StringType - case "string" => org.apache.spark.sql.types.StringType - case "binary" => org.apache.spark.sql.types.BinaryType - case "raw" => org.apache.spark.sql.types.BinaryType - case "logical" => org.apache.spark.sql.types.BooleanType - case "boolean" => org.apache.spark.sql.types.BooleanType - case "timestamp" => org.apache.spark.sql.types.TimestampType - case "date" => org.apache.spark.sql.types.DateType - case r"\Aarray<(.+)${elemType}>\Z" => - org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) - case 
r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" => - if (keyType != "string" && keyType != "character") { - throw new IllegalArgumentException("Key type of a map must be string or character") - } - org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) - case r"\Astruct<(.+)${fieldsStr}>\Z" => - if (fieldsStr(fieldsStr.length - 1) == ',') { - throw new IllegalArgumentException(s"Invalid type $dataType") - } - val fields = fieldsStr.split(",") - val structFields = fields.map { field => - field match { - case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" => - createStructField(fieldName, fieldType, true) - - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - createStructType(structFields) - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { - val dtObj = getSQLDataType(dataType) + val dtObj = CatalystSqlParser.parseDataType(dataType) StructField(name, dtObj, nullable) } From d228cd0b0243773a1c834414a240d1c553ab7af6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 29 Apr 2017 13:46:40 -0700 Subject: [PATCH 385/512] [SPARK-20442][PYTHON][DOCS] Fill up documentations for functions in Column API in PySpark ## What changes were proposed in this pull request? This PR proposes to fill up the documentation with examples for `bitwiseOR`, `bitwiseAND`, `bitwiseXOR`. `contains`, `asc` and `desc` in `Column` API. Also, this PR fixes minor typos in the documentation and matches some of the contents between Scala doc and Python doc. Lastly, this PR suggests to use `spark` rather than `sc` in doc tests in `Column` for Python documentation. ## How was this patch tested? Doc tests were added and manually tested with the commands below: `./python/run-tests.py --module pyspark-sql` `./python/run-tests.py --module pyspark-sql --python-executable python3` `./dev/lint-python` Output was checked via `make html` under `./python/docs`. The snapshots will be left on the codes with comments. Author: hyukjinkwon Closes #17737 from HyukjinKwon/SPARK-20442. --- python/pyspark/sql/column.py | 104 ++++++++++++++---- .../expressions/bitwiseExpressions.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 31 +++--- 3 files changed, 99 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 46c1707cb6c3..b8df37f25180 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -185,9 +185,43 @@ def __contains__(self, item): "in a string column or 'array_contains' function for an array column.") # bitwise operators - bitwiseOR = _bin_op("bitwiseOR") - bitwiseAND = _bin_op("bitwiseAND") - bitwiseXOR = _bin_op("bitwiseXOR") + _bitwiseOR_doc = """ + Compute bitwise OR of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise or(|) against + this :class:`Column`. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseOR(df.b)).collect() + [Row((a | b)=235)] + """ + _bitwiseAND_doc = """ + Compute bitwise AND of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise and(&) against + this :class:`Column`. 
+ + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseAND(df.b)).collect() + [Row((a & b)=10)] + """ + _bitwiseXOR_doc = """ + Compute bitwise XOR of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise xor(^) against + this :class:`Column`. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseXOR(df.b)).collect() + [Row((a ^ b)=225)] + """ + + bitwiseOR = _bin_op("bitwiseOR", _bitwiseOR_doc) + bitwiseAND = _bin_op("bitwiseAND", _bitwiseAND_doc) + bitwiseXOR = _bin_op("bitwiseXOR", _bitwiseXOR_doc) @since(1.3) def getItem(self, key): @@ -195,7 +229,7 @@ def getItem(self, key): An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. - >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df = spark.createDataFrame([([1, 2], {"key": "value"})], ["l", "d"]) >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() +----+------+ |l[0]|d[key]| @@ -217,7 +251,7 @@ def getField(self, name): An expression that gets a field by name in a StructField. >>> from pyspark.sql import Row - >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df = spark.createDataFrame([Row(r=Row(a=1, b="b"))]) >>> df.select(df.r.getField("b")).show() +---+ |r.b| @@ -250,8 +284,17 @@ def __iter__(self): raise TypeError("Column is not iterable") # string methods + _contains_doc = """ + Contains the other element. Returns a boolean :class:`Column` based on a string match. + + :param other: string in line + + >>> df.filter(df.name.contains('o')).collect() + [Row(age=5, name=u'Bob')] + """ _rlike_doc = """ - Return a Boolean :class:`Column` based on a regex match. + SQL RLIKE expression (LIKE with Regex). Returns a boolean :class:`Column` based on a regex + match. :param other: an extended regex expression @@ -259,7 +302,7 @@ def __iter__(self): [Row(age=2, name=u'Alice')] """ _like_doc = """ - Return a Boolean :class:`Column` based on a SQL LIKE match. + SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match. :param other: a SQL LIKE pattern @@ -269,9 +312,9 @@ def __iter__(self): [Row(age=2, name=u'Alice')] """ _startswith_doc = """ - Return a Boolean :class:`Column` based on a string match. + String starts with. Returns a boolean :class:`Column` based on a string match. - :param other: string at end of line (do not use a regex `^`) + :param other: string at start of line (do not use a regex `^`) >>> df.filter(df.name.startswith('Al')).collect() [Row(age=2, name=u'Alice')] @@ -279,7 +322,7 @@ def __iter__(self): [] """ _endswith_doc = """ - Return a Boolean :class:`Column` based on matching end of string. + String ends with. Returns a boolean :class:`Column` based on a string match. 
:param other: string at end of line (do not use a regex `$`) @@ -289,7 +332,7 @@ def __iter__(self): [] """ - contains = _bin_op("contains") + contains = ignore_unicode_prefix(_bin_op("contains", _contains_doc)) rlike = ignore_unicode_prefix(_bin_op("rlike", _rlike_doc)) like = ignore_unicode_prefix(_bin_op("like", _like_doc)) startswith = ignore_unicode_prefix(_bin_op("startsWith", _startswith_doc)) @@ -337,27 +380,40 @@ def isin(self, *cols): return Column(jc) # order - asc = _unary_op("asc", "Returns a sort expression based on the" - " ascending order of the given column name.") - desc = _unary_op("desc", "Returns a sort expression based on the" - " descending order of the given column name.") + _asc_doc = """ + Returns a sort expression based on the ascending order of the given column name + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.select(df.name).orderBy(df.name.asc()).collect() + [Row(name=u'Alice'), Row(name=u'Tom')] + """ + _desc_doc = """ + Returns a sort expression based on the descending order of the given column name. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.select(df.name).orderBy(df.name.desc()).collect() + [Row(name=u'Tom'), Row(name=u'Alice')] + """ + + asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) + desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) _isNull_doc = """ - True if the current expression is null. Often combined with - :func:`DataFrame.filter` to select rows with null values. + True if the current expression is null. >>> from pyspark.sql import Row - >>> df2 = sc.parallelize([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]).toDF() - >>> df2.filter(df2.height.isNull()).collect() + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.filter(df.height.isNull()).collect() [Row(height=None, name=u'Alice')] """ _isNotNull_doc = """ - True if the current expression is null. Often combined with - :func:`DataFrame.filter` to select rows with non-null values. + True if the current expression is NOT null. >>> from pyspark.sql import Row - >>> df2 = sc.parallelize([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]).toDF() - >>> df2.filter(df2.height.isNotNull()).collect() + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.filter(df.height.isNotNull()).collect() [Row(height=80, name=u'Tom')] """ @@ -527,7 +583,7 @@ def _test(): .appName("sql.column tests")\ .getOrCreate() sc = spark.sparkContext - globs['sc'] = sc + globs['spark'] = spark globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 291804077143..425efbb6c96c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -86,7 +86,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } /** - * A function that calculates bitwise xor of two numbers. 
+ * A function that calculates bitwise xor({@literal ^}) of two numbers. * * Code generation inherited from BinaryArithmetic. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 43de2de7e709..b23ab1fa3514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -779,7 +779,7 @@ class Column(val expr: Expression) extends Logging { def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } /** - * SQL like expression. + * SQL like expression. Returns a boolean column based on a SQL LIKE match. * * @group expr_ops * @since 1.3.0 @@ -787,7 +787,8 @@ class Column(val expr: Expression) extends Logging { def like(literal: String): Column = withExpr { Like(expr, lit(literal).expr) } /** - * SQL RLIKE expression (LIKE with Regex). + * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex + * match. * * @group expr_ops * @since 1.3.0 @@ -838,7 +839,7 @@ class Column(val expr: Expression) extends Logging { } /** - * Contains the other element. + * Contains the other element. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -846,7 +847,7 @@ class Column(val expr: Expression) extends Logging { def contains(other: Any): Column = withExpr { Contains(expr, lit(other).expr) } /** - * String starts with. + * String starts with. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -854,7 +855,7 @@ class Column(val expr: Expression) extends Logging { def startsWith(other: Column): Column = withExpr { StartsWith(expr, lit(other).expr) } /** - * String starts with another string literal. + * String starts with another string literal. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -862,7 +863,7 @@ class Column(val expr: Expression) extends Logging { def startsWith(literal: String): Column = this.startsWith(lit(literal)) /** - * String ends with. + * String ends with. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -870,7 +871,7 @@ class Column(val expr: Expression) extends Logging { def endsWith(other: Column): Column = withExpr { EndsWith(expr, lit(other).expr) } /** - * String ends with another string literal. + * String ends with another string literal. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -1008,7 +1009,7 @@ class Column(val expr: Expression) extends Logging { def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to)) /** - * Returns an ordering used in sorting. + * Returns a sort expression based on the descending order of the column. * {{{ * // Scala * df.sort(df("age").desc) @@ -1023,7 +1024,8 @@ class Column(val expr: Expression) extends Logging { def desc: Column = withExpr { SortOrder(expr, Descending) } /** - * Returns a descending ordering used in sorting, where null values appear before non-null values. + * Returns a sort expression based on the descending order of the column, + * and null values appear before non-null values. * {{{ * // Scala: sort a DataFrame by age column in descending order and null values appearing first. 
* df.sort(df("age").desc_nulls_first) @@ -1038,7 +1040,8 @@ class Column(val expr: Expression) extends Logging { def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) } /** - * Returns a descending ordering used in sorting, where null values appear after non-null values. + * Returns a sort expression based on the descending order of the column, + * and null values appear after non-null values. * {{{ * // Scala: sort a DataFrame by age column in descending order and null values appearing last. * df.sort(df("age").desc_nulls_last) @@ -1053,7 +1056,7 @@ class Column(val expr: Expression) extends Logging { def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) } /** - * Returns an ascending ordering used in sorting. + * Returns a sort expression based on ascending order of the column. * {{{ * // Scala: sort a DataFrame by age column in ascending order. * df.sort(df("age").asc) @@ -1068,7 +1071,8 @@ class Column(val expr: Expression) extends Logging { def asc: Column = withExpr { SortOrder(expr, Ascending) } /** - * Returns an ascending ordering used in sorting, where null values appear before non-null values. + * Returns a sort expression based on ascending order of the column, + * and null values return before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. * df.sort(df("age").asc_nulls_last) @@ -1083,7 +1087,8 @@ class Column(val expr: Expression) extends Logging { def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) } /** - * Returns an ordering used in sorting, where null values appear after non-null values. + * Returns a sort expression based on ascending order of the column, + * and null values appear after non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing last. * df.sort(df("age").asc_nulls_last) From 4d99b95ad0d0c7ef909c8e492ec45e94cf0189b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD=E5=B0=8F=E9=BE=99=2010207633?= Date: Sun, 30 Apr 2017 09:06:25 +0100 Subject: [PATCH 386/512] [SPARK-20521][DOC][CORE] The default of 'spark.worker.cleanup.appDataTtl' should be 604800 in spark-standalone.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently, our project needs to be set to clean up the worker directory cleanup cycle is three days. When I follow http://spark.apache.org/docs/latest/spark-standalone.html, configure the 'spark.worker.cleanup.appDataTtl' parameter, I configured to 3 * 24 * 3600. When I start the spark service, the startup fails, and the worker log displays the error log as follows: 2017-04-28 15:02:03,306 INFO Utils: Successfully started service 'sparkWorker' on port 48728. 
Exception in thread "main" java.lang.NumberFormatException: For input string: "3 * 24 * 3600" at java.lang.NumberFormatException.forInputString(NumberFormatException.java:65) at java.lang.Long.parseLong(Long.java:430) at java.lang.Long.parseLong(Long.java:483) at scala.collection.immutable.StringLike$class.toLong(StringLike.scala:276) at scala.collection.immutable.StringOps.toLong(StringOps.scala:29) at org.apache.spark.SparkConf$$anonfun$getLong$2.apply(SparkConf.scala:380) at org.apache.spark.SparkConf$$anonfun$getLong$2.apply(SparkConf.scala:380) at scala.Option.map(Option.scala:146) at org.apache.spark.SparkConf.getLong(SparkConf.scala:380) at org.apache.spark.deploy.worker.Worker.<init>(Worker.scala:100) at org.apache.spark.deploy.worker.Worker$.startRpcEnvAndEndpoint(Worker.scala:730) at org.apache.spark.deploy.worker.Worker$.main(Worker.scala:709) at org.apache.spark.deploy.worker.Worker.main(Worker.scala) **Because we put 7 * 24 * 3600 as a string, it is forcibly converted to a Long, which leads to problems in the program.** **So I think the default value shown for this configuration should be a concrete long value, 604800, rather than the expression 7 * 24 * 3600; the expression misleads users into similar configurations and results in Spark startup failures.** ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 郭小龙 10207633 Author: guoxiaolong Author: guoxiaolongzte Closes #17798 from guoxiaolongzte/SPARK-20521. --- docs/spark-standalone.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1c0b60f7b934..34ced9ed7b46 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -242,7 +242,7 @@ SPARK_WORKER_OPTS supports the following system properties:
    - + ``` Here `num` field represents number of attempts, this is not equal to REST APIs. In the REST API, if attempt id is not existed the URL should be `api/v1/applications//logs`, otherwise the URL should be `api/v1/applications///logs`. Using `` to represent `` will lead to the issue of "no such app". Manual verification. CC ajbozarth can you please review this change, since you add this feature before? Thanks! Author: jerryshao Closes #17795 from jerryshao/SPARK-20517. --- .../org/apache/spark/ui/static/historypage-template.html | 2 +- .../main/resources/org/apache/spark/ui/static/historypage.js | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 42e2d9abdeb5..6ba3b092dc65 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -77,7 +77,7 @@ - + {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 54810edaf146..1f89306403cd 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -120,6 +120,9 @@ $(document).ready(function() { attempt["startTime"] = formatDate(attempt["startTime"]); attempt["endTime"] = formatDate(attempt["endTime"]); attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + + (attempt.hasOwnProperty("attemptId") ? attempt["attemptId"] + "/" : "") + "logs"; + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; array.push(app_clone); } From 6fc6cf88d871f5b05b0ad1a504e0d6213cf9d331 Mon Sep 17 00:00:00 2001 From: Kunal Khamar Date: Mon, 1 May 2017 11:37:30 -0700 Subject: [PATCH 395/512] [SPARK-20464][SS] Add a job group and description for streaming queries and fix cancellation of running jobs using the job group ## What changes were proposed in this pull request? Job group: adding a job group is required to properly cancel running jobs related to a query. Description: the new description makes it easier to group the batches of a query by sorting by name in the Spark Jobs UI. ## How was this patch tested? - Unit tests - UI screenshot - Order by job id: ![screen shot 2017-04-27 at 5 10 09 pm](https://cloud.githubusercontent.com/assets/7865120/25509468/15452274-2b6e-11e7-87ba-d929816688cf.png) - Order by description: ![screen shot 2017-04-27 at 5 10 22 pm](https://cloud.githubusercontent.com/assets/7865120/25509474/1c298512-2b6e-11e7-99b8-fef1ef7665c1.png) - Order by job id (no query name): ![screen shot 2017-04-27 at 5 21 33 pm](https://cloud.githubusercontent.com/assets/7865120/25509482/28c96dc8-2b6e-11e7-8df0-9d3cdbb05e36.png) - Order by description (no query name): ![screen shot 2017-04-27 at 5 21 44 pm](https://cloud.githubusercontent.com/assets/7865120/25509489/37674742-2b6e-11e7-9357-b5c38ec16ac4.png) Author: Kunal Khamar Closes #17765 from kunalkhamar/sc-6696. 
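For context, here is a minimal standalone sketch (not part of this patch) of the SparkContext job-group API the change builds on; the group id, description, and toy job are illustrative assumptions, with the query's `runId` playing the role of the group id in the actual patch:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: jobs are tagged with a group id on the submitting thread and can then
// be cancelled as a group from any other thread, which is how stop() cancels a query's jobs.
object JobGroupSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("job-group-sketch").getOrCreate()
    val sc = spark.sparkContext

    val submitter = new Thread(new Runnable {
      override def run(): Unit = {
        // Job-group properties are per-thread, so set them on the thread that submits the job.
        // interruptOnCancel = true lets cancellation interrupt tasks that are already running.
        sc.setJobGroup("demo-run-id", "demo query / batch = 0", interruptOnCancel = true)
        try {
          sc.parallelize(1 to 10000, 4).foreach(_ => Thread.sleep(10))
        } catch {
          case e: Exception => println(s"job cancelled as expected: ${e.getMessage}")
        }
      }
    })
    submitter.start()

    Thread.sleep(2000)                // give the job time to start
    sc.cancelJobGroup("demo-run-id")  // cancels every job tagged with the group id
    submitter.join()
    spark.stop()
  }
}
```

Cancelling the group again after interrupting the micro-batch thread, as the patch does, covers any jobs that the interrupted thread may still manage to submit before it exits.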
--- .../scala/org/apache/spark/ui/UIUtils.scala | 2 +- .../execution/streaming/StreamExecution.scala | 12 ++++ .../spark/sql/streaming/StreamSuite.scala | 66 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index e53d6907bc40..79b0d81af52b 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -446,7 +446,7 @@ private[spark] object UIUtils extends Logging { val xml = XML.loadString(s"""$desc""") // Verify that this has only anchors and span (we are wrapping in span) - val allowedNodeLabels = Set("a", "span") + val allowedNodeLabels = Set("a", "span", "br") val illegalNodes = xml \\ "_" filterNot { case node: Node => allowedNodeLabels.contains(node.label) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bcf0d970f7ec..affc2018c43c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -252,6 +252,8 @@ class StreamExecution( */ private def runBatches(): Unit = { try { + sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString, + interruptOnCancel = true) if (sparkSession.sessionState.conf.streamingMetricsEnabled) { sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) } @@ -289,6 +291,7 @@ class StreamExecution( if (currentBatchId < 0) { // We'll do this initialization only once populateStartOffsets(sparkSessionToRunBatches) + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) logDebug(s"Stream running from $committedOffsets to $availableOffsets") } else { constructNextBatch() @@ -308,6 +311,7 @@ class StreamExecution( logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) } else { currentStatus = currentStatus.copy(isDataAvailable = false) updateStatusMessage("Waiting for data to arrive") @@ -684,8 +688,11 @@ class StreamExecution( // intentionally state.set(TERMINATED) if (microBatchThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) microBatchThread.interrupt() microBatchThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) } logInfo(s"Query $prettyIdString was stopped") } @@ -825,6 +832,11 @@ class StreamExecution( } } + private def getBatchDescriptionString: String = { + val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString + Option(name).map(_ + "
    ").getOrElse("") + + s"id = $id
    runId = $runId
    batch = $batchDescription" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 13fe51a55773..01ea62a9de4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -25,6 +25,8 @@ import scala.util.control.ControlThrowable import org.apache.commons.io.FileUtils +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand @@ -500,6 +502,70 @@ class StreamSuite extends StreamTest { } } } + + test("calling stop() on a query cancels related jobs") { + val input = MemoryStream[Int] + val query = input + .toDS() + .map { i => + while (!org.apache.spark.TaskContext.get().isInterrupted()) { + // keep looping till interrupted by query.stop() + Thread.sleep(100) + } + i + } + .writeStream + .format("console") + .start() + + input.addData(1) + // wait for jobs to start + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().nonEmpty) + } + + query.stop() + // make sure jobs are stopped + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().isEmpty) + } + } + + test("batch id is updated correctly in the job description") { + val queryName = "memStream" + @volatile var jobDescription: String = null + def assertDescContainsQueryNameAnd(batch: Integer): Unit = { + // wait for listener event to be processed + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + assert(jobDescription.contains(queryName) && jobDescription.contains(s"batch = $batch")) + } + + spark.sparkContext.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescription = jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION) + } + }) + + val input = MemoryStream[Int] + val query = input + .toDS() + .map(_ + 1) + .writeStream + .format("memory") + .queryName(queryName) + .start() + + input.addData(1) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 0) + input.addData(2, 3) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 1) + input.addData(4) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 2) + query.stop() + } } abstract class FakeSource extends StreamSourceProvider { From 2b2dd08e975dd7fbf261436aa877f1d7497ed31f Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 1 May 2017 14:48:02 -0700 Subject: [PATCH 396/512] [SPARK-20540][CORE] Fix unstable executor requests. There are two problems fixed in this commit. First, the ExecutorAllocationManager sets a timeout to avoid requesting executors too often. However, the timeout is always updated based on its value and a timeout, not the current time. If the call is delayed by locking for more than the ongoing scheduler timeout, the manager will request more executors on every run. This seems to be the main cause of SPARK-20540. The second problem is that the total number of requested executors is not tracked by the CoarseGrainedSchedulerBackend. 
Instead, it calculates the value based on the current status of 3 variables: the number of known executors, the number of executors that have been killed, and the number of pending executors. But, the number of pending executors is never less than 0, even though there may be more known than requested. When executors are killed and not replaced, this can cause the request sent to YARN to be incorrect because there were too many executors due to the scheduler's state being slightly out of date. This is fixed by tracking the currently requested size explicitly. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #17813 from rdblue/SPARK-20540-fix-dynamic-allocation. --- .../spark/ExecutorAllocationManager.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 32 ++++++++++++++++--- .../StandaloneDynamicAllocationSuite.scala | 6 ++-- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 261b3329a7b9..fcc72ff49276 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -331,7 +331,7 @@ private[spark] class ExecutorAllocationManager( val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") - addTime += sustainedSchedulerBacklogTimeoutS * 1000 + addTime = now + (sustainedSchedulerBacklogTimeoutS * 1000) delta } else { 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4eedaaea6119..dc82bb770472 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // `CoarseGrainedSchedulerBackend.this`. 
private val executorDataMap = new HashMap[String, ExecutorData] + // Number of executors requested by the cluster manager, [[ExecutorAllocationManager]] + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private var requestedTotalExecutors = 0 + // Number of executors requested from the cluster manager that have not registered yet @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 @@ -413,6 +417,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ protected def reset(): Unit = { val executors = synchronized { + requestedTotalExecutors = 0 numPendingExecutors = 0 executorsPendingToRemove.clear() Set() ++ executorDataMap.keys @@ -487,12 +492,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") val response = synchronized { + requestedTotalExecutors += numAdditionalExecutors numPendingExecutors += numAdditionalExecutors logDebug(s"Number of pending executors is now $numPendingExecutors") + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } // Account for executors pending to be added or removed - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + doRequestTotalExecutors(requestedTotalExecutors) } defaultAskTimeout.awaitResult(response) @@ -524,6 +538,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } val response = synchronized { + this.requestedTotalExecutors = numExecutors this.localityAwareTasks = localityAwareTasks this.hostToLocalTaskCount = hostToLocalTaskCount @@ -589,8 +604,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // take into account executors that are pending to be added or removed. 
val adjustTotalExecutors = if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } + doRequestTotalExecutors(requestedTotalExecutors) } else { numPendingExecutors += knownExecutors.size Future.successful(true) diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 9839dcf8535d..bf7480d79f8a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -356,12 +356,13 @@ class StandaloneDynamicAllocationSuite test("kill the same executor twice (SPARK-9795)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about @@ -380,12 +381,13 @@ class StandaloneDynamicAllocationSuite test("the pending replacement executors should not be lost (SPARK-10515)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about From af726cd6117de05c6e3b9616b8699d884a53651b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 1 May 2017 17:01:05 -0700 Subject: [PATCH 397/512] [SPARK-20459][SQL] JdbcUtils throws IllegalStateException: Cause already initialized after getting SQLException ## What changes were proposed in this pull request? Avoid failing to initCause on JDBC exception with cause initialized to null ## How was this patch tested? Existing tests Author: Sean Owen Closes #17800 from srowen/SPARK-20459. 
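For reference, a small hedged sketch (not part of the patch) of the JDK behaviour being worked around: `Throwable.initCause` throws `IllegalStateException` when a cause was already set through a constructor, even if that cause was explicitly `null`; the exception messages below are made up:

```scala
import java.sql.SQLException

object InitCauseSketch {
  def main(args: Array[String]): Unit = {
    val next = new SQLException("next exception in the chain")

    // Cause never initialised: initCause is allowed exactly once.
    val ok = new SQLException("no cause set")
    ok.initCause(next)

    // Cause explicitly initialised to null via the (String, Throwable) constructor:
    // initCause now throws IllegalStateException, so fall back to addSuppressed,
    // which is what the patched JdbcUtils does.
    val preInitialised = new SQLException("cause set to null", null: Throwable)
    try {
      preInitialised.initCause(next)
    } catch {
      case _: IllegalStateException => preInitialised.addSuppressed(next)
    }
    println(preInitialised.getSuppressed.length) // 1
  }
}
```

Catching the `IllegalStateException` and falling back to `addSuppressed` keeps the chained `SQLException` visible without assuming which of the two states the JDBC driver left the exception in.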
--- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 5fc3c2753b6c..0183805d5625 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -652,8 +652,17 @@ object JdbcUtils extends Logging { case e: SQLException => val cause = e.getNextException if (cause != null && e.getCause != cause) { + // If there is no cause already, set 'next exception' as cause. If cause is null, + // it *may* be because no cause was set yet if (e.getCause == null) { - e.initCause(cause) + try { + e.initCause(cause) + } catch { + // Or it may be null because the cause *was* explicitly initialized, to *null*, + // in which case this fails. There is no other way to detect it. + // addSuppressed in this case as well. + case _: IllegalStateException => e.addSuppressed(cause) + } } else { e.addSuppressed(cause) } From 259860d23d1740954b739b639c5bdc3ede65ed25 Mon Sep 17 00:00:00 2001 From: ptkool Date: Mon, 1 May 2017 17:05:35 -0700 Subject: [PATCH 398/512] [SPARK-20463] Add support for IS [NOT] DISTINCT FROM. ## What changes were proposed in this pull request? Add support for the SQL standard distinct predicate to SPARK SQL. ``` IS [NOT] DISTINCT FROM ``` ## How was this patch tested? Tested using unit tests, integration tests, manual tests. Author: ptkool Closes #17764 from ptkool/is_not_distinct_from. --- .../antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 5 +++++ .../spark/sql/catalyst/parser/ExpressionParserSuite.scala | 5 +++++ 3 files changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1ecb3d1958f4..14c511f67060 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -534,6 +534,7 @@ predicate | NOT? kind=IN '(' query ')' | NOT? kind=(RLIKE | LIKE) pattern=valueExpression | IS NOT? kind=NULL + | IS NOT? kind=DISTINCT FROM right=valueExpression ; valueExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a48a693a95c9..d2a9b4a9a9f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -935,6 +935,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * - (NOT) LIKE * - (NOT) RLIKE * - IS (NOT) NULL. + * - IS (NOT) DISTINCT FROM */ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { // Invert a predicate if it has a valid NOT clause. 
@@ -962,6 +963,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { IsNotNull(e) case SqlBaseParser.NULL => IsNull(e) + case SqlBaseParser.DISTINCT if ctx.NOT != null => + EqualNullSafe(e, expression(ctx.right)) + case SqlBaseParser.DISTINCT => + Not(EqualNullSafe(e, expression(ctx.right))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index e7f3b64a7113..eb68eb9851b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -167,6 +167,11 @@ class ExpressionParserSuite extends PlanTest { assertEqual("a = b is not null", ('a === 'b).isNotNull) } + test("is distinct expressions") { + assertEqual("a is distinct from b", !('a <=> 'b)) + assertEqual("a is not distinct from b", 'a <=> 'b) + } + test("binary arithmetic expressions") { // Simple operations assertEqual("a * b", 'a * 'b) From 943a684b9827ca294ed06a46431507538d40a134 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 1 May 2017 17:42:53 -0700 Subject: [PATCH 399/512] [SPARK-20548] Disable ReplSuite.newProductSeqEncoder with REPL defined class ## What changes were proposed in this pull request? `newProductSeqEncoder with REPL defined class` in `ReplSuite` has been failing in-deterministically : https://spark-tests.appspot.com/failed-tests over the last few days. Disabling the test until a fix is in place. https://spark.test.databricks.com/job/spark-master-test-sbt-hadoop-2.7/176/testReport/junit/org.apache.spark.repl/ReplSuite/newProductSeqEncoder_with_REPL_defined_class/history/ ## How was this patch tested? N/A Author: Sameer Agarwal Closes #17823 from sameeragarwal/disable-test. --- .../src/test/scala/org/apache/spark/repl/ReplSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 121a02a9be0a..8fe27080cac6 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -474,7 +474,8 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("Exception", output) } - test("newProductSeqEncoder with REPL defined class") { + // TODO: [SPARK-20548] Fix and re-enable + ignore("newProductSeqEncoder with REPL defined class") { val output = runInterpreterInPasteMode("local-cluster[1,4,4096]", """ |case class Click(id: Int) From d20a976e8918ca8d607af452301e8014fe14e64a Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 1 May 2017 21:03:48 -0700 Subject: [PATCH 400/512] [SPARK-20192][SPARKR][DOC] SparkR migration guide to 2.2.0 ## What changes were proposed in this pull request? Updating R Programming Guide ## How was this patch tested? manually Author: Felix Cheung Closes #17816 from felixcheung/r22relnote. --- docs/sparkr.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/sparkr.md b/docs/sparkr.md index 16b1ef651242..6dbd02a48890 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -644,3 +644,11 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 2.1.0 - `join` no longer performs Cartesian Product by default, use `crossJoin` instead. 
+ +## Upgrading to SparkR 2.2.0 + + - A `numPartitions` parameter has been added to `createDataFrame` and `as.DataFrame`. When splitting the data, the partition position calculation has been made to match the one in Scala. + - The method `createExternalTable` has been deprecated to be replaced by `createTable`. Either methods can be called to create external or managed table. Additional catalog methods have also been added. + - By default, derby.log is now saved to `tempdir()`. This will be created when instantiating the SparkSession with `enableHiveSupport` set to `TRUE`. + - `spark.lda` was not setting the optimizer correctly. It has been corrected. + - Several model summary outputs are updated to have `coefficients` as `matrix`. This includes `spark.logit`, `spark.kmeans`, `spark.glm`. Model summary outputs for `spark.gaussianMixture` have added log-likelihood as `loglik`. From 90d77e971f6b3fa268e411279f34bc1db4321991 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 1 May 2017 21:39:17 -0700 Subject: [PATCH 401/512] [SPARK-20532][SPARKR] Implement grouping and grouping_id ## What changes were proposed in this pull request? Adds R wrappers for: - `o.a.s.sql.functions.grouping` as `o.a.s.sql.functions.is_grouping` (to avoid shading `base::grouping` - `o.a.s.sql.functions.grouping_id` ## How was this patch tested? Existing unit tests, additional unit tests. `check-cran.sh`. Author: zero323 Closes #17807 from zero323/SPARK-20532. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 84 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 +++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 56 ++++++++++++++- 4 files changed, 148 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index e8de34d9371a..7ecd168137e8 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -249,6 +249,8 @@ exportMethods("%<=>%", "getField", "getItem", "greatest", + "grouping_bit", + "grouping_id", "hex", "histogram", "hour", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f9687d680e7a..38384a89919a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3890,3 +3890,87 @@ setMethod("not", jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) column(jc) }) + +#' grouping_bit +#' +#' Indicates whether a specified column in a GROUP BY list is aggregated or not, +#' returns 1 for aggregated or 0 for not aggregated in the result set. +#' +#' Same as \code{GROUPING} in SQL and \code{grouping} function in Scala. +#' +#' @param x Column to compute on +#' +#' @rdname grouping_bit +#' @name grouping_bit +#' @family agg_funcs +#' @aliases grouping_bit,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' } +#' @note grouping_bit since 2.3.0 +setMethod("grouping_bit", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "grouping", x@jc) + column(jc) + }) + +#' grouping_id +#' +#' Returns the level of grouping. +#' +#' Equals to \code{ +#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) +#' } +#' +#' @param x Column to compute on +#' @param ... additional Column(s) (optional). 
+#' +#' @rdname grouping_id +#' @name grouping_id +#' @family agg_funcs +#' @aliases grouping_id,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' } +#' @note grouping_id since 2.3.0 +setMethod("grouping_id", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ef36765a7a72..e02d46426a5a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1052,6 +1052,14 @@ setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") #' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) +#' @rdname grouping_bit +#' @export +setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) + +#' @rdname grouping_id +#' @export +setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) + #' @rdname hex #' @export setGeneric("hex", function(x) { standardGeneric("hex") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 08296354ca7e..12867c15d1f9 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1848,7 +1848,11 @@ test_that("test multi-dimensional aggregations with cube and rollup", { orderBy( agg( cube(df, "year", "department"), - expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + expr("sum(salary) AS total_salary"), + expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") ), "year", "department" ) @@ -1875,6 +1879,30 @@ test_that("test multi-dimensional aggregations with cube and rollup", { mean(c(21000, 32000, 22000)), # 2017 22000, 32000, 21000 # 2017 each department ), + grouping_year = c( + 1, # global + 1, 1, 1, # by department + 0, # 2016 + 0, 0, 0, # 2016 by department + 0, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_department = c( + 1, # global + 0, 0, 0, # by department + 1, # 2016 + 0, 0, 0, # 2016 by department + 1, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_id = c( + 3, # 11 + 2, 2, 2, # 10 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), stringsAsFactors = FALSE ) @@ -1896,7 +1924,10 @@ test_that("test multi-dimensional aggregations with cube and rollup", { orderBy( agg( rollup(df, "year", "department"), - expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary") + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") ), "year", "department" ) @@ -1920,6 +1951,27 @@ test_that("test multi-dimensional aggregations with cube and rollup", { mean(c(21000, 32000, 22000)), # 2017 22000, 32000, 21000 # 2017 each department ), + grouping_year = c( + 1, # global + 0, # 2016 + 0, 0, 0, # 2016 each department + 0, # 2017 + 0, 0, 0 # 2017 each department + ), + 
grouping_department = c( + 1, # global + 1, # 2016 + 0, 0, 0, # 2016 each department + 1, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_id = c( + 3, # 11 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), stringsAsFactors = FALSE ) From afb21bf22a59c9416c04637412fb69d1442e6826 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 2 May 2017 13:56:41 +0800 Subject: [PATCH 402/512] [SPARK-20537][CORE] Fixing OffHeapColumnVector reallocation ## What changes were proposed in this pull request? As #17773 revealed `OnHeapColumnVector` may copy a part of the original storage. `OffHeapColumnVector` reallocation also copies to the new storage data up to 'elementsAppended'. This variable is only updated when using the `ColumnVector.appendX` API, while `ColumnVector.putX` is more commonly used. This PR copies the new storage data up to the previously-allocated size in`OffHeapColumnVector`. ## How was this patch tested? Existing test suites Author: Kazuaki Ishizaki Closes #17811 from kiszk/SPARK-20537. --- .../vectorized/OffHeapColumnVector.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e988c0722bd7..a7d3744d00e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -436,28 +436,29 @@ public void loadBytes(ColumnVector.Array array) { // Split out the slow path. @Override protected void reserveInternal(int newCapacity) { + int oldCapacity = (this.data == 0L) ? 0 : capacity; if (this.resultArray != null) { this.lengthData = - Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = - Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); } else if (type instanceof ByteType || type instanceof BooleanType) { - this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. 
} else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + this.nulls = Platform.reallocateMemory(nulls, oldCapacity, newCapacity); + Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity); capacity = newCapacity; } } From 86174ea89b39a300caaba6baffac70f3dc702788 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 2 May 2017 14:08:16 +0800 Subject: [PATCH 403/512] [SPARK-20549] java.io.CharConversionException: Invalid UTF-32' in JsonToStructs ## What changes were proposed in this pull request? A fix for the same problem was made in #17693 but ignored `JsonToStructs`. This PR uses the same fix for `JsonToStructs`. ## How was this patch tested? Regression test Author: Burak Yavuz Closes #17826 from brkyvz/SPARK-20549. --- .../spark/sql/catalyst/expressions/jsonExpressions.scala | 8 +++----- .../spark/sql/catalyst/json/CreateJacksonParser.scala | 7 +++++-- .../sql/catalyst/expressions/JsonExpressionsSuite.scala | 7 +++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 9fb0ea68153d..6b90354367f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -151,8 +151,7 @@ case class GetJsonObject(json: Expression, path: Expression) try { /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson detect character encoding which could fail for some malformed strings */ - Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader( - new ByteArrayInputStream(jsonStr.getBytes), "UTF-8"))) { parser => + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, jsonStr)) { parser => val output = new ByteArrayOutputStream() val matched = Utils.tryWithResource( jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => @@ -398,9 +397,8 @@ case class JsonTuple(children: Seq[Expression]) try { /* We know the bytes are UTF-8 encoded. 
Pass a Reader to avoid having Jackson detect character encoding which could fail for some malformed strings */ - Utils.tryWithResource(jsonFactory.createParser(new InputStreamReader( - new ByteArrayInputStream(json.getBytes), "UTF-8"))) { - parser => parseRow(parser, input) + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => + parseRow(parser, input) } } catch { case _: JsonProcessingException => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index e0ed03a68981..025a388aacaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.InputStream +import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text @@ -33,7 +33,10 @@ private[sql] object CreateJacksonParser extends Serializable { val bb = record.getByteBuffer assert(bb.hasArray) - jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + val bain = new ByteArrayInputStream( + bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 4402ad4e9a9e..65d5c3a582b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -453,6 +453,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("SPARK-20549: from_json bad UTF-8") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + null) + } + test("from_json with timestamp") { val schema = StructType(StructField("t", TimestampType) :: Nil) From e300a5a145820ecd466885c73245d6684e8cb0aa Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Tue, 2 May 2017 10:49:13 +0200 Subject: [PATCH 404/512] [SPARK-20300][ML][PYSPARK] Python API for ALSModel.recommendForAllUsers,Items Add Python API for `ALSModel` methods `recommendForAllUsers`, `recommendForAllItems` ## How was this patch tested? New doc tests. Author: Nick Pentreath Closes #17622 from MLnick/SPARK-20300-pyspark-recall. 
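For context, a hedged sketch (not part of the patch) of the existing Scala API that these new Python wrappers delegate to; the toy ratings, column names, and hyperparameters are illustrative only:

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

object RecommendForAllSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("als-sketch").getOrCreate()
    import spark.implicits._

    // Tiny, made-up ratings DataFrame: (user, item, rating).
    val ratings = Seq((0, 0, 4.0f), (0, 1, 2.0f), (1, 1, 3.0f), (2, 0, 1.0f), (2, 1, 5.0f))
      .toDF("user", "item", "rating")

    val model = new ALS()
      .setUserCol("user").setItemCol("item").setRatingCol("rating")
      .setRank(5).setMaxIter(5).setSeed(0L)
      .fit(ratings)

    // Top-2 recommendations per user and per item, returned as arrays of (id, rating) structs.
    model.recommendForAllUsers(2).show(truncate = false)
    model.recommendForAllItems(2).show(truncate = false)

    spark.stop()
  }
}
```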
--- python/pyspark/ml/recommendation.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 8bc899a0788b..bcfb36880eb0 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) + >>> user_recs = model.recommendForAllUsers(3) + >>> user_recs.where(user_recs.user == 0)\ + .select("recommendations.item", "recommendations.rating").collect() + [Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])] + >>> item_recs = model.recommendForAllItems(3) + >>> item_recs.where(item_recs.item == 2)\ + .select("recommendations.user", "recommendations.rating").collect() + [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -384,6 +392,28 @@ def itemFactors(self): """ return self._call_java("itemFactors") + @since("2.2.0") + def recommendForAllUsers(self, numItems): + """ + Returns top `numItems` items recommended for each user, for all users. + + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForAllUsers", numItems) + + @since("2.2.0") + def recommendForAllItems(self, numUsers): + """ + Returns top `numUsers` users recommended for each item, for all items. + + :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. + """ + return self._call_java("recommendForAllItems", numUsers) + if __name__ == "__main__": import doctest From b1e639ab09d3a7a1545119e45a505c9a04308353 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 2 May 2017 16:49:24 +0800 Subject: [PATCH 405/512] [SPARK-19235][SQL][TEST][FOLLOW-UP] Enable Test Cases in DDLSuite with Hive Metastore ### What changes were proposed in this pull request? This is a follow-up of enabling test cases in DDLSuite with Hive Metastore. It consists of the following remaining tasks: - Run all the `alter table` and `drop table` DDL tests against data source tables when using Hive metastore. - Do not run any `alter table` and `drop table` DDL test against Hive serde tables when using InMemoryCatalog. - Reenable `alter table: set serde partition` and `alter table: set serde` tests for Hive serde tables. ### How was this patch tested? N/A Author: Xiao Li Closes #17524 from gatorsmile/cleanupDDLSuite. 
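For context, a hedged sketch (not part of the patch) of the kind of DDL exercised by the re-enabled `set serde` / `set serde partition` tests against a Hive-backed catalog; the table name, serde class, and partition spec are illustrative only:

```scala
import org.apache.spark.sql.SparkSession

object HiveSerdeDdlSketch {
  def main(args: Array[String]): Unit = {
    // Requires a Spark build with Hive support on the classpath.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("hive-serde-ddl-sketch")
      .enableHiveSupport()
      .getOrCreate()

    // A partitioned Hive serde table; data source tables reject SET SERDE, which is
    // exactly the split the refactored DDLSuite / HiveDDLSuite tests verify.
    spark.sql("CREATE TABLE tab1 (a INT, b INT) PARTITIONED BY (p INT) STORED AS TEXTFILE")
    spark.sql("ALTER TABLE tab1 ADD PARTITION (p = 1)")

    spark.sql("ALTER TABLE tab1 SET SERDE 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'")
    spark.sql("ALTER TABLE tab1 PARTITION (p = 1) SET SERDEPROPERTIES ('k' = 'v')")

    spark.sql("DROP TABLE tab1")
    spark.stop()
  }
}
```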
--- .../sql/execution/command/DDLSuite.scala | 291 ++++++++---------- .../apache/spark/sql/test/SQLTestUtils.scala | 3 +- .../sql/hive/execution/HiveDDLSuite.scala | 73 ++++- 3 files changed, 195 insertions(+), 172 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 2f4eb1b15519..0abcff76060f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -49,7 +49,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable = { val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() @@ -70,46 +71,6 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo tracksPartitionsInCatalog = true) } - test("alter table: set location (datasource table)") { - testSetLocation(isDatasourceTable = true) - } - - test("alter table: set properties (datasource table)") { - testSetProperties(isDatasourceTable = true) - } - - test("alter table: unset properties (datasource table)") { - testUnsetProperties(isDatasourceTable = true) - } - - test("alter table: set serde (datasource table)") { - testSetSerde(isDatasourceTable = true) - } - - test("alter table: set serde partition (datasource table)") { - testSetSerdePartition(isDatasourceTable = true) - } - - test("alter table: change column (datasource table)") { - testChangeColumn(isDatasourceTable = true) - } - - test("alter table: add partition (datasource table)") { - testAddPartitions(isDatasourceTable = true) - } - - test("alter table: drop partition (datasource table)") { - testDropPartitions(isDatasourceTable = true) - } - - test("alter table: rename partition (datasource table)") { - testRenamePartitions(isDatasourceTable = true) - } - - test("drop table - data source table") { - testDropTable(isDatasourceTable = true) - } - test("create a managed Hive source table") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") val tabName = "tbl" @@ -163,7 +124,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive" } - protected def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable + protected def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -205,8 +169,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { ignoreIfExists = false) } - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { - catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) + private def createTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): Unit = { + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( @@ -223,6 +190,46 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri } + test("alter table: set location (datasource table)") { + 
testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: rename partition (datasource table)") { + testRenamePartitions(isDatasourceTable = true) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + test("the qualified path of a database is stored in the catalog") { val catalog = spark.sessionState.catalog @@ -835,32 +842,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("alter table: set location") { - testSetLocation(isDatasourceTable = false) - } - - test("alter table: set properties") { - testSetProperties(isDatasourceTable = false) - } - - test("alter table: unset properties") { - testUnsetProperties(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde") { - testSetSerde(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde partition") { - testSetSerdePartition(isDatasourceTable = false) - } - - test("alter table: change column") { - testChangeColumn(isDatasourceTable = false) - } - test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -885,10 +866,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } - test("alter table: add partition") { - testAddPartitions(isDatasourceTable = false) - } - test("alter table: recover partitions (sequential)") { withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { testRecoverPartitions() @@ -957,17 +934,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } - test("alter table: drop partition") { - testDropPartitions(isDatasourceTable = false) - } - test("alter table: drop partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") } - test("alter table: rename partition") { - testRenamePartitions(isDatasourceTable = false) - } test("show databases") { sql("CREATE DATABASE showdb2B") @@ -1011,18 +981,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(catalog.listTables("default") == Nil) } - test("drop table") { - testDropTable(isDatasourceTable = false) - } - protected def testDropTable(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - 
convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) assert(catalog.listTables("dbx") == Seq(tableIdent)) sql("DROP TABLE dbx.tab1") assert(catalog.listTables("dbx") == Nil) @@ -1046,22 +1012,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) } - private def convertToDatasourceTable( - catalog: SessionCatalog, - tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( - provider = Some("csv"))) - assert(catalog.getTableMetadata(tableIdent).provider == Some("csv")) - } - protected def testSetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1084,13 +1042,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1121,15 +1079,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetLocation(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, partSpec, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) @@ -1171,13 +1129,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerde(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val 
serdeProp = catalog.getTableMetadata(tableIdent).storage.properties if (isUsingHiveMetastore) { @@ -1187,8 +1145,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getTableMetadata(tableIdent).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(expectedSerde)) } else { assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) } @@ -1229,18 +1191,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val spec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, spec, tableIdent) createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties if (isUsingHiveMetastore) { @@ -1250,8 +1212,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getPartition(tableIdent, spec).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getPartition(tableIdent, spec).storage.serde == Some(expectedSerde)) } else { assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) } @@ -1295,6 +1261,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testAddPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1303,11 +1272,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // basic add partition @@ -1354,6 +1320,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testDropPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog 
val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1362,7 +1331,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) @@ -1370,9 +1339,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { createTablePartition(catalog, part5, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3, part4, part5)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic drop partition sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") @@ -1407,20 +1373,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "q") val part2 = Map("a" -> "2", "b" -> "c") val part3 = Map("a" -> "3", "b" -> "p") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic rename partition sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") @@ -1451,14 +1417,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testChangeColumn(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val resolver = spark.sessionState.conf.resolver val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getMetadata(colName: String): Metadata = { val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field => resolver(field.name, colName) @@ -1601,13 +1567,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("drop current database") { - sql("CREATE DATABASE temp") - sql("USE temp") - sql("DROP DATABASE temp") - val e = intercept[AnalysisException] { + withDatabase("temp") { + sql("CREATE DATABASE temp") + sql("USE temp") + sql("DROP DATABASE temp") + val e = intercept[AnalysisException] { sql("CREATE TABLE t (a INT, b INT) USING parquet") }.getMessage - assert(e.contains("Database 'temp' not found")) + assert(e.contains("Database 'temp' not found")) + } } test("drop default database") { @@ -1837,22 +1805,25 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("tbl"), Row(1)) val defaultTablePath = spark.sessionState.catalog 
.getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get - - sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") - spark.catalog.refreshTable("tbl") - // SET LOCATION won't move data from previous table path to new table path. - assert(spark.table("tbl").count() == 0) - // the previous table path should be still there. - assert(new File(defaultTablePath).exists()) - - sql("INSERT INTO tbl SELECT 2") - checkAnswer(spark.table("tbl"), Row(2)) - // newly inserted data will go to the new table path. - assert(dir.listFiles().nonEmpty) - - sql("DROP TABLE tbl") - // the new table path will be removed after DROP TABLE. - assert(!dir.exists()) + try { + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") + spark.catalog.refreshTable("tbl") + // SET LOCATION won't move data from previous table path to new table path. + assert(spark.table("tbl").count() == 0) + // the previous table path should be still there. + assert(new File(defaultTablePath).exists()) + + sql("INSERT INTO tbl SELECT 2") + checkAnswer(spark.table("tbl"), Row(2)) + // newly inserted data will go to the new table path. + assert(dir.listFiles().nonEmpty) + + sql("DROP TABLE tbl") + // the new table path will be removed after DROP TABLE. + assert(!dir.exists()) + } finally { + Utils.deleteRecursively(new File(defaultTablePath)) + } } } } @@ -2125,7 +2096,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"location uri contains $specialChars for database") { - try { + withDatabase ("tmpdb") { withTable("t") { withTempDir { dir => val loc = new File(dir, specialChars) @@ -2140,8 +2111,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(tblloc.listFiles().nonEmpty) } } - } finally { - spark.sql("DROP DATABASE IF EXISTS tmpdb") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 44c0fc70d066..f6d47734d7e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -237,7 +237,7 @@ private[sql] trait SQLTestUtils try f(dbName) finally { if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE ${DEFAULT_DATABASE}") + spark.sql(s"USE $DEFAULT_DATABASE") } spark.sql(s"DROP DATABASE $dbName CASCADE") } @@ -251,6 +251,7 @@ private[sql] trait SQLTestUtils dbNames.foreach { name => spark.sql(s"DROP DATABASE IF EXISTS $name") } + spark.sql(s"USE $DEFAULT_DATABASE") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 16a99321bad3..341e03b5e57f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -50,15 +50,28 @@ class HiveCatalogedDDLSuite extends DDLSuite with 
TestHiveSingleton with BeforeA protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean): CatalogTable = { val storage = - CatalogStorageFormat( - locationUri = Some(catalog.defaultTablePath(name)), - inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), - serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), - compressed = false, - properties = Map("serialization.format" -> "1")) + if (isDataSource) { + val serde = HiveSerDe.sourceToSerDe("parquet") + assert(serde.isDefined, "The default format is not Hive compatible") + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = serde.get.inputFormat, + outputFormat = serde.get.outputFormat, + serde = serde.get.serde, + compressed = false, + properties = Map("serialization.format" -> "1")) + } else { + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), + compressed = false, + properties = Map("serialization.format" -> "1")) + } val metadata = new MetadataBuilder() .putString("key", "value") .build() @@ -71,7 +84,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA .add("col2", "string") .add("a", "int") .add("b", "int"), - provider = Some("hive"), + provider = if (isDataSource) Some("parquet") else Some("hive"), partitionColumnNames = Seq("a", "b"), createTime = 0L, tracksPartitionsInCatalog = true) @@ -107,6 +120,46 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA ) } + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set properties") { + testSetProperties(isDatasourceTable = false) + } + + test("alter table: unset properties") { + testUnsetProperties(isDatasourceTable = false) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde partition") { + testSetSerdePartition(isDatasourceTable = false) + } + + test("alter table: change column") { + testChangeColumn(isDatasourceTable = false) + } + + test("alter table: rename partition") { + testRenamePartitions(isDatasourceTable = false) + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + } class HiveDDLSuite From 13f47dc5033a99df8d9ec18f2ce373119462f7bc Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 2 May 2017 09:37:01 -0700 Subject: [PATCH 406/512] [SPARK-20490][SPARKR][DOC] add family tag for not function ## What changes were proposed in this pull request? doc only ## How was this patch tested? manual Author: Felix Cheung Closes #17828 from felixcheung/rnotfamily. 
--- R/pkg/R/functions.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 38384a89919a..3d47b09ce551 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3871,6 +3871,7 @@ setMethod("posexplode_outer", #' @rdname not #' @name not #' @aliases not,Column-method +#' @family normal_funcs #' @export #' @examples \dontrun{ #' df <- createDataFrame(data.frame( From ef3df9125a30f8fb817fe855b74d7130be45b0ee Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 2 May 2017 14:30:06 -0700 Subject: [PATCH 407/512] [SPARK-20421][CORE] Add a missing deprecation tag. In the previous patch I deprecated StorageStatus, but not the method in SparkContext that exposes that class publicly. So deprecate the method too. Author: Marcelo Vanzin Closes #17824 from vanzin/SPARK-20421. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0ec1bdd39b2f..f7c32e5f0cec 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1734,6 +1734,7 @@ class SparkContext(config: SparkConf) extends Logging { * Return information about blocks stored in all of the slaves */ @DeveloperApi + @deprecated("This method may change or be removed in a future release.", "2.2.0") def getExecutorStorageStatus: Array[StorageStatus] = { assertNotStopped() env.blockManager.master.getStorageStatus From b946f3160eb7953fb30edf1f097ea87be75b33e7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 May 2017 10:08:46 +0800 Subject: [PATCH 408/512] [SPARK-20558][CORE] clear InheritableThreadLocal variables in SparkContext when stopping it ## What changes were proposed in this pull request? To better understand this problem, let's take a look at an example first: ``` object Main { def main(args: Array[String]): Unit = { var t = new Test new Thread(new Runnable { override def run() = {} }).start() println("first thread finished") t.a = null t = new Test new Thread(new Runnable { override def run() = {} }).start() } } class Test { var a = new InheritableThreadLocal[String] { override protected def childValue(parent: String): String = { println("parent value is: " + parent) parent } } a.set("hello") } ``` The result is: ``` parent value is: hello first thread finished parent value is: hello parent value is: hello ``` Once an `InheritableThreadLocal` has been set value, child threads will inherit its value as long as it has not been GCed, so setting the variable which holds the `InheritableThreadLocal` to `null` doesn't work as we expected. In `SparkContext`, we have an `InheritableThreadLocal` for local properties, we should clear it when stopping `SparkContext`, or all the future child threads will still inherit it and copy the properties and waste memory. This is the root cause of https://issues.apache.org/jira/browse/SPARK-20548 , which creates/stops `SparkContext` many times and finally have a lot of `InheritableThreadLocal` alive, and cause OOM when starting new threads in the internal thread pools. ## How was this patch tested? N/A Author: Wenchen Fan Closes #17833 from cloud-fan/core. 
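A minimal standalone sketch (not part of the patch; the `Main2` object name is made up) of the same point: once an `InheritableThreadLocal` entry exists in the parent thread, only `remove()` takes it out of that thread's inheritable map, so later child threads stop inheriting it — dropping the reference that holds the `InheritableThreadLocal`, as in the example above, does not.

```scala
object Main2 {
  def main(args: Array[String]): Unit = {
    val prop = new InheritableThreadLocal[String] {
      override protected def childValue(parent: String): String = {
        println("parent value is: " + parent)
        parent
      }
    }
    prop.set("hello")
    // The child thread copies the entry at construction: prints "parent value is: hello".
    new Thread(new Runnable { override def run(): Unit = {} }).start()

    prop.remove() // removes the entry from the current thread's inheritable map
    // Nothing left to copy, so childValue is never called for this thread.
    new Thread(new Runnable { override def run(): Unit = {} }).start()
  }
}
```

This is why the patch calls `localProperties.remove()` in `SparkContext.stop()` instead of relying on the stopped `SparkContext` becoming unreachable.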
--- core/src/main/scala/org/apache/spark/SparkContext.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f7c32e5f0cec..7dbceb9c5c1a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1939,6 +1939,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this + // `SparkContext` is stopped. + localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() From 6235132a8ce64bb12d825d0a65e5dd052d1ee647 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 May 2017 22:44:27 -0700 Subject: [PATCH 409/512] [SPARK-20567] Lazily bind in GenerateExec It is not valid to eagerly bind with the child's output as this causes failures when we attempt to canonicalize the plan (replacing the attribute references with dummies). Author: Michael Armbrust Closes #17838 from marmbrus/fixBindExplode. --- .../spark/sql/execution/GenerateExec.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 1812a1152cb4..c35e5638e927 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -78,7 +78,7 @@ case class GenerateExec( override def outputPartitioning: Partitioning = child.outputPartitioning - val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) + lazy val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index f796a4cb4a39..4345a70601c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -69,6 +69,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte ) } + test("count distinct") { + val inputData = MemoryStream[(Int, Seq[Int])] + + val aggregated = + inputData.toDF() + .select($"*", explode($"_2") as 'value) + .groupBy($"_1") + .agg(size(collect_set($"value"))) + .as[(Int, Int)] + + testStream(aggregated, Update)( + AddData(inputData, (1, Seq(1, 2))), + CheckLastBatch((1, 2)) + ) + } + test("simple count, complete mode") { val inputData = MemoryStream[Int] From db2fb84b4a3c45daa449cc9232340193ce8eb37d Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 3 May 2017 10:58:05 +0200 Subject: [PATCH 410/512] [SPARK-6227][MLLIB][PYSPARK] Implement PySpark wrappers for SVD and PCA (v2) Add PCA and SVD to PySpark's wrappers for `RowMatrix` and `IndexedRowMatrix` (SVD only). Based on #7963, updated. ## How was this patch tested? New doc tests and unit tests. 
Ran all examples locally. Author: MechCoder Author: Nick Pentreath Closes #17621 from MLnick/SPARK-6227-pyspark-svd-pca. --- docs/mllib-dimensionality-reduction.md | 29 +-- .../spark/examples/mllib/JavaPCAExample.java | 27 ++- .../spark/examples/mllib/JavaSVDExample.java | 27 +-- .../python/mllib/pca_rowmatrix_example.py | 46 ++++ examples/src/main/python/mllib/svd_example.py | 48 +++++ .../mllib/PCAOnRowMatrixExample.scala | 4 +- .../spark/examples/mllib/SVDExample.scala | 11 +- python/pyspark/mllib/linalg/distributed.py | 199 +++++++++++++++++- python/pyspark/mllib/tests.py | 63 ++++++ 9 files changed, 408 insertions(+), 46 deletions(-) create mode 100644 examples/src/main/python/mllib/pca_rowmatrix_example.py create mode 100644 examples/src/main/python/mllib/svd_example.py diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 539cbc1b3163..a72680d52a26 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -76,13 +76,14 @@ Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/ The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. + +
    +Refer to the [`SingularValueDecomposition` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.SingularValueDecomposition) for details on the API. -In order to run the above application, follow the instructions -provided in the [Self-Contained -Applications](quick-start.html#self-contained-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +{% include_example python/mllib/svd_example.py %} +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`.
    @@ -118,17 +119,21 @@ Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feat The following code demonstrates how to compute principal components on a `RowMatrix` and use them to project the vectors into a low-dimensional space. -The number of columns should be small, e.g, less than 1000. Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. {% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %} - -In order to run the above application, follow the instructions -provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) -section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +
    + +The following code demonstrates how to compute principal components on a `RowMatrix` +and use them to project the vectors into a low-dimensional space. + +Refer to the [`RowMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) for details on the API. + +{% include_example python/mllib/pca_rowmatrix_example.py %} + +
    + diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java index 3077f557ef88..0a7dc621e111 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java @@ -18,7 +18,8 @@ package org.apache.spark.examples.mllib; // $example on$ -import java.util.LinkedList; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; @@ -39,21 +40,25 @@ public class JavaPCAExample { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("PCA Example"); SparkContext sc = new SparkContext(conf); + JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); // $example on$ - double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; - LinkedList rowsList = new LinkedList<>(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + List data = Arrays.asList( + Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + + JavaRDD rows = jsc.parallelize(data); // Create a RowMatrix from JavaRDD. RowMatrix mat = new RowMatrix(rows.rdd()); - // Compute the top 3 principal components. - Matrix pc = mat.computePrincipalComponents(3); + // Compute the top 4 principal components. + // Principal components are stored in a local dense matrix. + Matrix pc = mat.computePrincipalComponents(4); + + // Project the rows to the linear space spanned by the top 4 principal components. RowMatrix projected = mat.multiply(pc); // $example off$ Vector[] collectPartitions = (Vector[])projected.rows().collect(); @@ -61,6 +66,6 @@ public static void main(String[] args) { for (Vector vector : collectPartitions) { System.out.println("\t" + vector); } - sc.stop(); + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java index 3730e60f6880..802be3960a33 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java @@ -18,7 +18,8 @@ package org.apache.spark.examples.mllib; // $example on$ -import java.util.LinkedList; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; @@ -43,22 +44,22 @@ public static void main(String[] args) { JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); // $example on$ - double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; - LinkedList rowsList = new LinkedList<>(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = jsc.parallelize(rowsList); + List data = Arrays.asList( + Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + + JavaRDD rows = jsc.parallelize(data); // Create a RowMatrix from JavaRDD. RowMatrix mat = new RowMatrix(rows.rdd()); - // Compute the top 3 singular values and corresponding singular vectors. 
- SingularValueDecomposition svd = mat.computeSVD(3, true, 1.0E-9d); - RowMatrix U = svd.U(); - Vector s = svd.s(); - Matrix V = svd.V(); + // Compute the top 5 singular values and corresponding singular vectors. + SingularValueDecomposition svd = mat.computeSVD(5, true, 1.0E-9d); + RowMatrix U = svd.U(); // The U factor is a RowMatrix. + Vector s = svd.s(); // The singular values are stored in a local dense vector. + Matrix V = svd.V(); // The V factor is a local dense matrix. // $example off$ Vector[] collectPartitions = (Vector[]) U.rows().collect(); System.out.println("U factor is:"); diff --git a/examples/src/main/python/mllib/pca_rowmatrix_example.py b/examples/src/main/python/mllib/pca_rowmatrix_example.py new file mode 100644 index 000000000000..49b9b1bbe08e --- /dev/null +++ b/examples/src/main/python/mllib/pca_rowmatrix_example.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonPCAOnRowMatrixExample") + + # $example on$ + rows = sc.parallelize([ + Vectors.sparse(5, {1: 1.0, 3: 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ]) + + mat = RowMatrix(rows) + # Compute the top 4 principal components. + # Principal components are stored in a local dense matrix. + pc = mat.computePrincipalComponents(4) + + # Project the rows to the linear space spanned by the top 4 principal components. + projected = mat.multiply(pc) + # $example off$ + collected = projected.rows.collect() + print("Projected Row Matrix of principal component:") + for vector in collected: + print(vector) + sc.stop() diff --git a/examples/src/main/python/mllib/svd_example.py b/examples/src/main/python/mllib/svd_example.py new file mode 100644 index 000000000000..5b220fdb3fd6 --- /dev/null +++ b/examples/src/main/python/mllib/svd_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonSVDExample") + + # $example on$ + rows = sc.parallelize([ + Vectors.sparse(5, {1: 1.0, 3: 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ]) + + mat = RowMatrix(rows) + + # Compute the top 5 singular values and corresponding singular vectors. + svd = mat.computeSVD(5, computeU=True) + U = svd.U # The U factor is a RowMatrix. + s = svd.s # The singular values are stored in a local dense vector. + V = svd.V # The V factor is a local dense matrix. + # $example off$ + collected = U.rows.collect() + print("U factor is:") + for vector in collected: + print(vector) + print("Singular values are: %s" % s) + print("V factor is:\n%s" % V) + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala index a137ba2a2f9d..da43a8d9c7e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala @@ -39,9 +39,9 @@ object PCAOnRowMatrixExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val dataRDD = sc.parallelize(data, 2) + val rows = sc.parallelize(data) - val mat: RowMatrix = new RowMatrix(dataRDD) + val mat: RowMatrix = new RowMatrix(rows) // Compute the top 4 principal components. // Principal components are stored in a local dense matrix. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala index b286a3f7b909..769ae2a3a88b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala @@ -28,6 +28,9 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix // $example off$ +/** + * Example for SingularValueDecomposition. + */ object SVDExample { def main(args: Array[String]): Unit = { @@ -41,15 +44,15 @@ object SVDExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val dataRDD = sc.parallelize(data, 2) + val rows = sc.parallelize(data) - val mat: RowMatrix = new RowMatrix(dataRDD) + val mat: RowMatrix = new RowMatrix(rows) // Compute the top 5 singular values and corresponding singular vectors. val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true) val U: RowMatrix = svd.U // The U factor is a RowMatrix. - val s: Vector = svd.s // The singular values are stored in a local dense vector. - val V: Matrix = svd.V // The V factor is a local dense matrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. 
// $example off$ val collect = U.rows.collect() println("U factor is:") diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 600655c912ca..4cb802514be5 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -28,14 +28,13 @@ from pyspark import RDD, since from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector, Matrix, QRDecomposition +from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition from pyspark.mllib.stat import MultivariateStatisticalSummary from pyspark.storagelevel import StorageLevel -__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow', - 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix', - 'BlockMatrix'] +__all__ = ['BlockMatrix', 'CoordinateMatrix', 'DistributedMatrix', 'IndexedRow', + 'IndexedRowMatrix', 'MatrixEntry', 'RowMatrix', 'SingularValueDecomposition'] class DistributedMatrix(object): @@ -301,6 +300,136 @@ def tallSkinnyQR(self, computeQ=False): R = decomp.call("R") return QRDecomposition(Q, R) + @since('2.2.0') + def computeSVD(self, k, computeU=False, rCond=1e-9): + """ + Computes the singular value decomposition of the RowMatrix. + + The given row matrix A of dimension (m X n) is decomposed into + U * s * V'T where + + * U: (m X k) (left singular vectors) is a RowMatrix whose + columns are the eigenvectors of (A X A') + * s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + * v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) + + For more specific details on implementation, please refer + the Scala documentation. + + :param k: Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + :param computeU: Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + :param rCond: Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + :returns: :py:class:`SingularValueDecomposition` + + >>> rows = sc.parallelize([[3, 1, 1], [-1, 3, 1]]) + >>> rm = RowMatrix(rows) + + >>> svd_model = rm.computeSVD(2, True) + >>> svd_model.U.rows.collect() + [DenseVector([-0.7071, 0.7071]), DenseVector([-0.7071, -0.7071])] + >>> svd_model.s + DenseVector([3.4641, 3.1623]) + >>> svd_model.V + DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0) + """ + j_model = self._java_matrix_wrapper.call( + "computeSVD", int(k), bool(computeU), float(rCond)) + return SingularValueDecomposition(j_model) + + @since('2.2.0') + def computePrincipalComponents(self, k): + """ + Computes the k principal components of the given row matrix + + .. note:: This cannot be computed on matrices with more than 65535 columns. + + :param k: Number of principal components to keep. 
+ :returns: :py:class:`pyspark.mllib.linalg.DenseMatrix` + + >>> rows = sc.parallelize([[1, 2, 3], [2, 4, 5], [3, 6, 1]]) + >>> rm = RowMatrix(rows) + + >>> # Returns the two principal components of rm + >>> pca = rm.computePrincipalComponents(2) + >>> pca + DenseMatrix(3, 2, [-0.349, -0.6981, 0.6252, -0.2796, -0.5592, -0.7805], 0) + + >>> # Transform into new dimensions with the greatest variance. + >>> rm.multiply(pca).rows.collect() # doctest: +NORMALIZE_WHITESPACE + [DenseVector([0.1305, -3.7394]), DenseVector([-0.3642, -6.6983]), \ + DenseVector([-4.6102, -4.9745])] + """ + return self._java_matrix_wrapper.call("computePrincipalComponents", k) + + @since('2.2.0') + def multiply(self, matrix): + """ + Multiply this matrix by a local dense matrix on the right. + + :param matrix: a local dense matrix whose number of rows must match the number of columns + of this matrix + :returns: :py:class:`RowMatrix` + + >>> rm = RowMatrix(sc.parallelize([[0, 1], [2, 3]])) + >>> rm.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() + [DenseVector([2.0, 3.0]), DenseVector([6.0, 11.0])] + """ + if not isinstance(matrix, DenseMatrix): + raise ValueError("Only multiplication with DenseMatrix " + "is supported.") + j_model = self._java_matrix_wrapper.call("multiply", matrix) + return RowMatrix(j_model) + + +class SingularValueDecomposition(JavaModelWrapper): + """ + Represents singular value decomposition (SVD) factors. + + .. versionadded:: 2.2.0 + """ + + @property + @since('2.2.0') + def U(self): + """ + Returns a distributed matrix whose columns are the left + singular vectors of the SingularValueDecomposition if computeU was set to be True. + """ + u = self.call("U") + if u is not None: + mat_name = u.getClass().getSimpleName() + if mat_name == "RowMatrix": + return RowMatrix(u) + elif mat_name == "IndexedRowMatrix": + return IndexedRowMatrix(u) + else: + raise TypeError("Expected RowMatrix/IndexedRowMatrix got %s" % mat_name) + + @property + @since('2.2.0') + def s(self): + """ + Returns a DenseVector with singular values in descending order. + """ + return self.call("s") + + @property + @since('2.2.0') + def V(self): + """ + Returns a DenseMatrix whose columns are the right singular + vectors of the SingularValueDecomposition. + """ + return self.call("V") + class IndexedRow(object): """ @@ -528,6 +657,68 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): colsPerBlock) return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + @since('2.2.0') + def computeSVD(self, k, computeU=False, rCond=1e-9): + """ + Computes the singular value decomposition of the IndexedRowMatrix. + + The given row matrix A of dimension (m X n) is decomposed into + U * s * V'T where + + * U: (m X k) (left singular vectors) is a IndexedRowMatrix + whose columns are the eigenvectors of (A X A') + * s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + * v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) + + For more specific details on implementation, please refer + the scala documentation. + + :param k: Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + :param computeU: Whether or not to compute U. 
If set to be + True, then U is computed by A * V * s^-1 + :param rCond: Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + :returns: SingularValueDecomposition object + + >>> rows = [(0, (3, 1, 1)), (1, (-1, 3, 1))] + >>> irm = IndexedRowMatrix(sc.parallelize(rows)) + >>> svd_model = irm.computeSVD(2, True) + >>> svd_model.U.rows.collect() # doctest: +NORMALIZE_WHITESPACE + [IndexedRow(0, [-0.707106781187,0.707106781187]),\ + IndexedRow(1, [-0.707106781187,-0.707106781187])] + >>> svd_model.s + DenseVector([3.4641, 3.1623]) + >>> svd_model.V + DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0) + """ + j_model = self._java_matrix_wrapper.call( + "computeSVD", int(k), bool(computeU), float(rCond)) + return SingularValueDecomposition(j_model) + + @since('2.2.0') + def multiply(self, matrix): + """ + Multiply this matrix by a local dense matrix on the right. + + :param matrix: a local dense matrix whose number of rows must match the number of columns + of this matrix + :returns: :py:class:`IndexedRowMatrix` + + >>> mat = IndexedRowMatrix(sc.parallelize([(0, (0, 1)), (1, (2, 3))])) + >>> mat.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() + [IndexedRow(0, [2.0,3.0]), IndexedRow(1, [6.0,11.0])] + """ + if not isinstance(matrix, DenseMatrix): + raise ValueError("Only multiplication with DenseMatrix " + "is supported.") + return IndexedRowMatrix(self._java_matrix_wrapper.call("multiply", matrix)) + class MatrixEntry(object): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 523b3f111331..1037bab7f108 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -23,6 +23,7 @@ import sys import tempfile import array as pyarray +from math import sqrt from time import time, sleep from shutil import rmtree @@ -54,6 +55,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD @@ -1699,6 +1701,67 @@ def test_binary_term_freqs(self): ": expected " + str(expected[i]) + ", got " + str(output[i])) +class DimensionalityReductionTests(MLlibTestCase): + + denseData = [ + Vectors.dense([0.0, 1.0, 2.0]), + Vectors.dense([3.0, 4.0, 5.0]), + Vectors.dense([6.0, 7.0, 8.0]), + Vectors.dense([9.0, 0.0, 1.0]) + ] + sparseData = [ + Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), + Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), + Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), + Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) + ] + + def assertEqualUpToSign(self, vecA, vecB): + eq1 = vecA - vecB + eq2 = vecA + vecB + self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) + + def test_svd(self): + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + m = 4 + n = 3 + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + rm = mat.computeSVD(k, computeU=True) + self.assertEqual(rm.s.size, k) + self.assertEqual(rm.U.numRows(), m) + self.assertEqual(rm.U.numCols(), k) + self.assertEqual(rm.V.numRows, n) + self.assertEqual(rm.V.numCols, k) + + # 
Test that U returned is None if computeU is set to False. + self.assertEqual(mat.computeSVD(1).U, None) + + # Test that low rank matrices cannot have number of singular values + # greater than a limit. + rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) + self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) + + def test_pca(self): + expected_pcs = array([ + [0.0, 1.0, 0.0], + [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], + [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] + ]) + n = 3 + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + pcs = mat.computePrincipalComponents(k) + self.assertEqual(pcs.numRows, n) + self.assertEqual(pcs.numCols, k) + + # We can just test the updated principal component for equality. + self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) + + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: From 16fab6b0ef3dcb33f92df30e17680922ad5fb672 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 3 May 2017 10:18:35 +0100 Subject: [PATCH 411/512] [SPARK-20523][BUILD] Clean up build warnings for 2.2.0 release ## What changes were proposed in this pull request? Fix build warnings primarily related to Breeze 0.13 operator changes, Java style problems ## How was this patch tested? Existing tests Author: Sean Owen Closes #17803 from srowen/SPARK-20523. --- .../spark/network/yarn/YarnShuffleService.java | 4 ++-- .../java/org/apache/spark/unsafe/Platform.java | 3 ++- .../org/apache/spark/memory/TaskMemoryManager.java | 3 ++- .../spark/scheduler/TaskSetManagerSuite.scala | 11 ++++++----- .../storage/BlockReplicationPolicySuite.scala | 1 + dev/checkstyle-suppressions.xml | 4 ++++ .../streaming/JavaStructuredSessionization.java | 2 -- .../org/apache/spark/graphx/lib/PageRank.scala | 14 +++++++------- .../org/apache/spark/ml/ann/LossFunction.scala | 4 ++-- .../spark/ml/clustering/GaussianMixture.scala | 2 +- .../spark/mllib/clustering/GaussianMixture.scala | 2 +- .../apache/spark/mllib/clustering/LDAModel.scala | 8 ++++---- .../spark/mllib/clustering/LDAOptimizer.scala | 12 ++++++------ .../apache/spark/mllib/clustering/LDAUtils.scala | 2 +- .../spark/ml/classification/NaiveBayesSuite.scala | 2 +- pom.xml | 4 ---- .../cluster/YarnSchedulerBackendSuite.scala | 2 ++ .../spark/sql/streaming/GroupStateTimeout.java | 5 ++++- .../expressions/JsonExpressionsSuite.scala | 2 +- .../parquet/SpecificParquetRecordReaderBase.java | 5 +++-- .../spark/sql/execution/QueryExecutionSuite.scala | 2 ++ .../streaming/StreamingQueryListenerSuite.scala | 1 + .../spark/sql/hive/execution/HiveDDLSuite.scala | 2 +- 23 files changed, 54 insertions(+), 43 deletions(-) diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 4acc203153e5..fd50e3a4bfb9 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -363,9 +363,9 @@ protected File initRecoveryDb(String dbName) { // If another DB was initialized first just make sure all the DBs are in the same // location. 
Path newLoc = new Path(_recoveryPath, dbName); - Path copyFrom = new Path(f.toURI()); + Path copyFrom = new Path(f.toURI()); if (!newLoc.equals(copyFrom)) { - logger.info("Moving " + copyFrom + " to: " + newLoc); + logger.info("Moving " + copyFrom + " to: " + newLoc); try { // The move here needs to handle moving non-empty directories across NFS mounts FileSystem fs = FileSystem.getLocal(_conf); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 1321b8318115..4ab5b6889c21 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -48,7 +48,8 @@ public final class Platform { boolean _unaligned; String arch = System.getProperty("os.arch", ""); if (arch.equals("ppc64le") || arch.equals("ppc64")) { - // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but ppc64 and ppc64le support it + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but + // ppc64 and ppc64le support it _unaligned = true; } else { try { diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index aa0b37323132..5f9141174916 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -155,7 +155,8 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { long key = c.getUsed(); - List list = sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); list.add(c); } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9ca6b8b0fe63..db14c9acfdce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1070,11 +1070,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) - when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).then(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - assert(manager.isZombie === true) - } - }) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie) + } + }) val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index dfecd04c1b96..4000218e71a8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.language.implicitConversions import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 31656ca0e5a6..bb7d31cad7be 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -44,4 +44,8 @@ files="src/main/java/org/apache/hive/service/server/ThreadWithGarbageCleanup.java"/> + + diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java index d3c8516882fa..6b8e6554f1bb 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -28,8 +28,6 @@ import java.sql.Timestamp; import java.util.*; -import scala.Tuple2; - /** * Counts words in UTF8 encoded, '\n' delimited text received from the network. *

    diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 13b2b5771918..fd7b7f7c1c48 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -226,18 +226,18 @@ object PageRank extends Logging { // Propagates the message along outbound edges // and adding start nodes back in with activation resetProb val rankUpdates = rankGraph.aggregateMessages[BV[Double]]( - ctx => ctx.sendToDst(ctx.srcAttr :* ctx.attr), - (a : BV[Double], b : BV[Double]) => a :+ b, TripletFields.Src) + ctx => ctx.sendToDst(ctx.srcAttr *:* ctx.attr), + (a : BV[Double], b : BV[Double]) => a +:+ b, TripletFields.Src) rankGraph = rankGraph.outerJoinVertices(rankUpdates) { (vid, oldRank, msgSumOpt) => - val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) :* (1.0 - resetProb) + val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) *:* (1.0 - resetProb) val resetActivations = if (sourcesInitMapBC.value contains vid) { - sourcesInitMapBC.value(vid) :* resetProb + sourcesInitMapBC.value(vid) *:* resetProb } else { zero } - popActivations :+ resetActivations + popActivations +:+ resetActivations }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -250,9 +250,9 @@ object PageRank extends Logging { } // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks - val rankSums = rankGraph.vertices.values.fold(zero)(_ :+ _) + val rankSums = rankGraph.vertices.values.fold(zero)(_ +:+ _) rankGraph.mapVertices { (vid, attr) => - Vectors.fromBreeze(attr :/ rankSums) + Vectors.fromBreeze(attr /:/ rankSums) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala index 32d78e9b226e..3aea568cd652 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala @@ -56,7 +56,7 @@ private[ann] class SigmoidLayerModelWithSquaredError extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction { override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) - val error = Bsum(delta :* delta) / 2 / output.cols + val error = Bsum(delta *:* delta) / 2 / output.cols ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o)) error } @@ -119,6 +119,6 @@ private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) - -Bsum( target :* brzlog(output)) / output.cols + -Bsum( target *:* brzlog(output)) / output.cols } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index a9c1a7ba0bc8..5259ee419445 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -472,7 +472,7 @@ class GaussianMixture @Since("2.0.0") ( */ val cov = { val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze - slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) :^ 2.0) + 
slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) ^:^ 2.0) val diagVec = Vectors.fromBreeze(ss) BLAS.scal(1.0 / numSamples, diagVec) val covVec = new DenseVector(Array.fill[Double]( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 051ec2404fb6..4d952ac88c9b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -271,7 +271,7 @@ class GaussianMixture private ( private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) val ss = BDV.zeros[Double](x(0).length) - x.foreach(xi => ss += (xi - mu) :^ 2.0) + x.foreach(xi => ss += (xi - mu) ^:^ 2.0) diag(ss / x.length.toDouble) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 15b723dadcff..663f63c25a94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -314,7 +314,7 @@ class LocalLDAModel private[spark] ( docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t) } // E[log p(theta | alpha) - log q(theta | gamma)] - docBound += sum((brzAlpha - gammad) :* Elogthetad) + docBound += sum((brzAlpha - gammad) *:* Elogthetad) docBound += sum(lgamma(gammad) - lgamma(brzAlpha)) docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) @@ -324,7 +324,7 @@ class LocalLDAModel private[spark] ( // Bound component for prob(topic-term distributions): // E[log p(beta | eta) - log q(beta | lambda)] val sumEta = eta * vocabSize - val topicsPart = sum((eta - lambda) :* Elogbeta) + + val topicsPart = sum((eta - lambda) *:* Elogbeta) + sum(lgamma(lambda) - lgamma(eta)) + sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) @@ -721,7 +721,7 @@ class DistributedLDAModel private[clustering] ( val N_wj = edgeContext.attr val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) edgeContext.sendToDst(tokenLogLikelihood) @@ -748,7 +748,7 @@ class DistributedLDAModel private[clustering] ( if (isTermVertex(vertex)) { val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k sumPrior + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 3697a9b46dd8..d633893e55f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -482,7 +482,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) stats.unpersist() expElogbetaBc.destroy(false) - val batchResult = statsSum :* expElogbeta.t + val batchResult = statsSum *:* expElogbeta.t // 
Note that this is an optimization to avoid batch.count updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) @@ -522,7 +522,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val dalpha = -(gradf - b) / q - if (all((weight * dalpha + alpha) :> 0D)) { + if (all((weight * dalpha + alpha) >:> 0D)) { alpha :+= weight * dalpha this.alpha = Vectors.dense(alpha.toArray) } @@ -584,7 +584,7 @@ private[clustering] object OnlineLDAOptimizer { val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids @@ -592,14 +592,14 @@ private[clustering] object OnlineLDAOptimizer { while (meanGammaChange > 1e-3) { val lastgamma = gammad.copy // K K * ids ids - gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha + gammad := (expElogthetad *:* (expElogbetad.t * (ctsVector /:/ phiNorm))) +:+ alpha expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) // TODO: Keep more values in log space, and only exponentiate when needed. - phiNorm := expElogbetad * expElogthetad :+ 1e-100 + phiNorm := expElogbetad * expElogthetad +:+ 1e-100 meanGammaChange = sum(abs(gammad - lastgamma)) / k } - val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector /:/ phiNorm).asDenseMatrix (gammad, sstatsd, ids) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index 1f6e1a077f92..c4bbe51a46c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -29,7 +29,7 @@ private[clustering] object LDAUtils { */ private[clustering] def logSumExp(x: BDV[Double]): Double = { val a = max(x) - a + log(sum(exp(x :- a))) + a + log(sum(exp(x -:- a))) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index b56f8e19ca53..3a2be236f125 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -168,7 +168,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(m1.pi ~== m2.pi relTol 0.01) assert(m1.theta ~== m2.theta relTol 0.01) } - val testParams = Seq( + val testParams = Seq[(String, Dataset[_])]( ("bernoulli", bernoulliDataset), ("multinomial", dataset) ) diff --git a/pom.xml b/pom.xml index 517ebc5c83fc..a1a1817e2f7d 100644 --- a/pom.xml +++ b/pom.xml @@ -58,10 +58,6 @@ https://issues.apache.org/jira/browse/SPARK - - ${maven.version} - - Dev Mailing List diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala index 4079d9e40fc4..0a413b2c23de 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala +++ 
b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.scheduler.cluster +import scala.language.reflectiveCalls + import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index bd5e2d7ecca9..5f1032d1229d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -37,7 +37,9 @@ public class GroupStateTimeout { * `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation * on `GroupState` for more details. */ - public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } + public static GroupStateTimeout ProcessingTimeTimeout() { + return ProcessingTimeTimeout$.MODULE$; + } /** * Timeout based on event-time. The event-time timestamp for timeout can be set for each @@ -51,4 +53,5 @@ public class GroupStateTimeout { /** No timeout. */ public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 65d5c3a582b1..f892e8020460 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -41,7 +41,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { /* invalid json with leading nulls would trigger java.io.CharConversionException in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */ - val badJson = "\0\0\0A\1AAA" + val badJson = "\u0000\u0000\u0000A\u0001AAA" test("$.store.bicycle") { checkEvaluation( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 0bab321a657d..5a810cae1e18 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -66,7 +66,6 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType$; import org.apache.spark.util.AccumulatorV2; -import org.apache.spark.util.LongAccumulator; /** * Base class for custom RecordReaders for Parquet that directly materialize to `T`. 
@@ -160,7 +159,9 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont if (taskContext != null) { Option> accu = taskContext.taskMetrics().externalAccums().lastOption(); if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { - ((AccumulatorV2)accu.get()).add(blocks.size()); + @SuppressWarnings("unchecked") + AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); + intAccum.add(blocks.size()); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 1c1931b6a6da..05637821f71f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution import java.util.Locale +import scala.language.reflectiveCalls + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b8a694c17731..59c6a6fade17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -21,6 +21,7 @@ import java.util.UUID import scala.collection.mutable import scala.concurrent.duration._ +import scala.language.reflectiveCalls import org.scalactic.TolerantNumerics import org.scalatest.concurrent.AsyncAssertions.Waiter diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 341e03b5e57f..c3d734e5a036 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -183,7 +183,7 @@ class HiveDDLSuite if (dbPath.isEmpty) { hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier) } else { - new Path(new Path(dbPath.get), tableIdentifier.table) + new Path(new Path(dbPath.get), tableIdentifier.table).toUri } val filesystemPath = new Path(expectedTablePath.toString) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) From 7f96f2d7f2d5abf81dd7f8ca27fea35cf798fd65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 3 May 2017 10:54:40 +0100 Subject: [PATCH 412/512] [SPARK-16957][MLLIB] Use midpoints for split values. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Use midpoints for split values now, and maybe later to make it weighted. ## How was this patch tested? + [x] add unit test. + [x] revise Split's unit test. Author: Yan Facai (颜发才) Author: 颜发才(Yan Facai) Closes #17556 from facaiy/ENH/decision_tree_overflow_and_precision_in_aggregation. 
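For illustration, here is a minimal self-contained sketch of the midpoint rule this patch applies; the helper name `midpointSplits` and the standalone form are assumptions for the example, not code from `RandomForest.scala`:

```scala
// Sketch only: candidate thresholds are the midpoints of adjacent distinct values,
// so a constant feature yields no splits and a threshold never equals a sample value.
def midpointSplits(sortedDistinctValues: Array[Double]): Array[Double] = {
  if (sortedDistinctValues.length < 2) {
    Array.empty[Double]
  } else {
    sortedDistinctValues.sliding(2).map { case Array(lo, hi) => (lo + hi) / 2.0 }.toArray
  }
}

// midpointSplits(Array(0.0, 1.0))             // Array(0.5)
// midpointSplits(Array(0.0, 1.0, 2.0, 3.0))   // Array(0.5, 1.5, 2.5)
```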
--- .../spark/ml/tree/impl/RandomForest.scala | 15 ++++--- .../ml/tree/impl/RandomForestSuite.scala | 41 ++++++++++++++++--- python/pyspark/mllib/tree.py | 12 +++--- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 008dd19c2498..82e1ed85a0a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = if (featureSamples.isEmpty) { + val splits: Array[Double] = if (featureSamples.isEmpty) { Array.empty[Double] } else { val numSplits = metadata.numSplits(featureIndex) @@ -1009,10 +1009,15 @@ private[spark] object RandomForest extends Logging { // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray - // if possible splits is not enough or just enough, just return all possible splits val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { - valueCounts.map(_._1).init + if (possibleSplits == 0) { + // constant feature + Array.empty[Double] + } else if (possibleSplits <= numSplits) { + // if possible splits is not enough or just enough, just return all possible splits + (1 to possibleSplits) + .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0) + .toArray } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1037,7 +1042,7 @@ private[spark] object RandomForest extends Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 + splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0 targetCount += stride } index += 1 diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index e1ab7c2d6520..df155b464c64 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -104,6 +104,31 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits.distinct.length === splits.length) } + // SPARK-16957: Use midpoints for split values. 
+ { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + + // possibleSplits <= numSplits + { + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2) + assert(splits === expectedSplits) + } + + // possibleSplits > numSplits + { + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) + } + } + // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { @@ -112,9 +137,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0, 2.0)) + val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -126,9 +152,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) + .map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(2.0, 3.0)) + val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2) + assert(splits === expectedSplits) } // find splits when most samples close to the maximum @@ -138,9 +166,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0)) + val expectedSplits = Array((1.0 + 2.0) / 2) + assert(splits === expectedSplits) } // find splits for constant feature diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index a6089fc8b9d3..619fa16d463f 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -199,9 +199,9 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes - If (feature 0 <= 0.0) + If (feature 0 <= 0.5) Predict: 0.0 - Else (feature 0 > 0.0) + Else (feature 0 > 0.5) Predict: 1.0 >>> model.predict(array([1.0])) @@ -383,14 +383,14 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Tree 0: Predict: 1.0 Tree 1: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 Tree 2: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else 
(feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 >>> model.predict([2.0]) From 27f543b15f2f493f6f8373e46b4c9564b0a1bf81 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Wed, 3 May 2017 08:55:02 -0700 Subject: [PATCH 413/512] [SPARK-20441][SPARK-20432][SS] Within the same streaming query, one StreamingRelation should only be transformed to one StreamingExecutionRelation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Within the same streaming query, when one `StreamingRelation` is referred multiple times – e.g. `df.union(df)` – we should transform it only to one `StreamingExecutionRelation`, instead of two or more different `StreamingExecutionRelation`s (each of which would have a separate set of source, source logs, ...). ## How was this patch tested? Added two test cases, each of which would fail without this patch. Author: Liwei Lin Closes #17735 from lw-lin/SPARK-20441. --- .../execution/streaming/StreamExecution.scala | 20 ++++---- .../spark/sql/streaming/StreamSuite.scala | 48 +++++++++++++++++++ 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index affc2018c43c..b6ddf7437ea1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.locks.ReentrantLock +import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -148,15 +149,18 @@ class StreamExecution( "logicalPlan must be initialized in StreamExecutionThread " + s"but the current thread was ${Thread.currentThread}") var nextSourceId = 0L + val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val _logicalPlan = analyzedPlan.transform { - case StreamingRelation(dataSource, _, output) => - // Materialize source to avoid creating it in every batch - val metadataPath = s"$checkpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) - nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. - StreamingExecutionRelation(source, output) + case streamingRelation@StreamingRelation(dataSource, _, output) => + toExecutionRelationMap.getOrElseUpdate(streamingRelation, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$checkpointRoot/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. 
+ StreamingExecutionRelation(source, output) + }) } sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } uniqueSources = sources.distinct diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 01ea62a9de4d..1fc062974e18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -71,6 +71,27 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("SPARK-20432: union one stream with itself") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a") + val unioned = df.union(df) + withTempDir { outputDir => + withTempDir { checkpointDir => + val query = + unioned + .writeStream.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*) + } finally { + query.stop() + } + } + } + } + test("union two streams") { val inputData1 = MemoryStream[Int] val inputData2 = MemoryStream[Int] @@ -122,6 +143,33 @@ class StreamSuite extends StreamTest { assertDF(df) } + test("Within the same streaming query, one StreamingRelation should only be transformed to one " + + "StreamingExecutionRelation") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load() + var query: StreamExecution = null + try { + query = + df.union(df) + .writeStream + .format("memory") + .queryName("memory") + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + query.awaitInitialization(streamingTimeout.toMillis) + val executionRelations = + query + .logicalPlan + .collect { case ser: StreamingExecutionRelation => ser } + assert(executionRelations.size === 2) + assert(executionRelations.distinct.size === 1) + } finally { + if (query != null) { + query.stop() + } + } + } + test("unsupported queries") { val streamInput = MemoryStream[Int] val batchInput = Seq(1, 2, 3).toDS() From 527fc5d0c990daaacad4740f62cfe6736609b77b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 May 2017 09:22:25 -0700 Subject: [PATCH 414/512] [SPARK-20576][SQL] Support generic hint function in Dataset/DataFrame ## What changes were proposed in this pull request? We allow users to specify hints (currently only "broadcast" is supported) in SQL and DataFrame. However, while SQL has a standard hint format (/*+ ... */), DataFrame doesn't have one and sometimes users are confused that they can't find how to apply a broadcast hint. This ticket adds a generic hint function on DataFrame that allows using the same hint on DataFrames as well as SQL. As an example, after this patch, the following will apply a broadcast hint on a DataFrame using the new hint function: ``` df1.join(df2.hint("broadcast")) ``` ## How was this patch tested? Added a test case in DataFrameJoinSuite. Author: Reynold Xin Closes #17839 from rxin/SPARK-20576. 
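As a usage sketch of the new API (the session setup and range sizes below are assumptions for the example, not part of this patch):

```scala
// Sketch: apply the broadcast hint through Dataset.hint and inspect the physical plan.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("hint-example").getOrCreate()

val large = spark.range(1000000L).toDF("id")
val small = spark.range(100L).toDF("id")

// Hint that the smaller side should be broadcast, independent of size estimates.
val joined = large.join(small.hint("broadcast"), "id")
joined.explain()  // the plan is expected to contain a broadcast hash join
```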
--- .../sql/catalyst/analysis/ResolveHints.scala | 8 +++++++- .../scala/org/apache/spark/sql/Dataset.scala | 16 ++++++++++++++++ .../apache/spark/sql/DataFrameJoinSuite.scala | 18 +++++++++++++++++- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index c4827b81e8b6..df688fa0e58a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -86,7 +86,13 @@ object ResolveHints { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => - applyBroadcastHint(h.child, h.parameters.toSet) + if (h.parameters.isEmpty) { + // If there is no table alias specified, turn the entire subtree into a BroadcastHint. + BroadcastHint(h.child) + } else { + // Otherwise, find within the subtree query plans that should be broadcasted. + applyBroadcastHint(h.child, h.parameters.toSet) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 147e7651ce55..620c8bd54ba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1160,6 +1160,22 @@ class Dataset[T] private[sql]( */ def apply(colName: String): Column = col(colName) + /** + * Specifies some hint on the current Dataset. As an example, the following code specifies + * that one of the plan can be broadcasted: + * + * {{{ + * df1.join(df2.hint("broadcast")) + * }}} + * + * @group basic + * @since 2.2.0 + */ + @scala.annotation.varargs + def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { + Hint(name, parameters, logicalPlan) + } + /** * Selects column based on the column name and return it as a [[Column]]. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 541ffb58e727..4a52af6c32c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -151,7 +151,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } - test("broadcast join hint") { + test("broadcast join hint using broadcast function") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") @@ -174,6 +174,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } } + test("broadcast join hint using Dataset.hint") { + // make sure a giant join is not broadcastable + val plan1 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong), "id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // now with a hint it should be broadcasted + val plan2 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong).hint("broadcast"), "id") + .queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) + } + test("join - outer join conversion") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") From 6b9e49d12fc4c9b29d497122daa4cc9bf4540b16 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Wed, 3 May 2017 11:10:24 -0700 Subject: [PATCH 415/512] [SPARK-19965][SS] DataFrame batch reader may fail to infer partitions when reading FileStreamSink's output ## The Problem Right now DataFrame batch reader may fail to infer partitions when reading FileStreamSink's output: ``` [info] - partitioned writing and batch reading with 'basePath' *** FAILED *** (3 seconds, 928 milliseconds) [info] java.lang.AssertionError: assertion failed: Conflicting directory structures detected. Suspicious paths: [info] ***/stream.output-65e3fa45-595a-4d29-b3df-4c001e321637 [info] ***/stream.output-65e3fa45-595a-4d29-b3df-4c001e321637/_spark_metadata [info] [info] If provided paths are partition directories, please set "basePath" in the options of the data source to specify the root directory of the table. If there are multiple root directories, please load them separately and then union them. 
[info] at scala.Predef$.assert(Predef.scala:170) [info] at org.apache.spark.sql.execution.datasources.PartitioningUtils$.parsePartitions(PartitioningUtils.scala:133) [info] at org.apache.spark.sql.execution.datasources.PartitioningUtils$.parsePartitions(PartitioningUtils.scala:98) [info] at org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning(PartitioningAwareFileIndex.scala:156) [info] at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.partitionSpec(InMemoryFileIndex.scala:54) [info] at org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.partitionSchema(PartitioningAwareFileIndex.scala:55) [info] at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:133) [info] at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:361) [info] at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:160) [info] at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:536) [info] at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:520) [info] at org.apache.spark.sql.streaming.FileStreamSinkSuite$$anonfun$8.apply$mcV$sp(FileStreamSinkSuite.scala:292) [info] at org.apache.spark.sql.streaming.FileStreamSinkSuite$$anonfun$8.apply(FileStreamSinkSuite.scala:268) [info] at org.apache.spark.sql.streaming.FileStreamSinkSuite$$anonfun$8.apply(FileStreamSinkSuite.scala:268) ``` ## What changes were proposed in this pull request? This patch alters `InMemoryFileIndex` to filter out these `basePath`s whose ancestor is the streaming metadata dir (`_spark_metadata`). E.g., the following and other similar dir or files will be filtered out: - (introduced by globbing `basePath/*`) - `basePath/_spark_metadata` - (introduced by globbing `basePath/*/*`) - `basePath/_spark_metadata/0` - `basePath/_spark_metadata/1` - ... ## How was this patch tested? Added unit tests Author: Liwei Lin Closes #17346 from lw-lin/filter-metadata. --- .../datasources/InMemoryFileIndex.scala | 13 +++- .../execution/streaming/FileStreamSink.scala | 20 +++++++ .../datasources/FileSourceStrategySuite.scala | 2 +- .../sql/streaming/FileStreamSinkSuite.scala | 59 ++++++++++++++++++- 4 files changed, 90 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 9897ab73b0da..91e31650617e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.execution.streaming.FileStreamSink import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -36,20 +37,28 @@ import org.apache.spark.util.SerializableConfiguration * A [[FileIndex]] that generates the list of files to process by recursively listing all the * files present in `paths`. 
* - * @param rootPaths the list of root table paths to scan + * @param rootPathsSpecified the list of root table paths to scan (some of which might be + * filtered out later) * @param parameters as set of options to control discovery * @param partitionSchema an optional partition schema that will be use to provide types for the * discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, - override val rootPaths: Seq[Path], + rootPathsSpecified: Seq[Path], parameters: Map[String, String], partitionSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( sparkSession, parameters, partitionSchema, fileStatusCache) { + // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) + // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain + // such streaming metadata dir or files, e.g. when after globbing "basePath/*" where "basePath" + // is the output of a streaming query. + override val rootPaths = + rootPathsSpecified.filterNot(FileStreamSink.ancestorIsMetadataDirectory(_, hadoopConf)) + @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ @volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _ @volatile private var cachedPartitionSpec: PartitionSpec = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 07ec4e9429e4..6885d0bf67cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -53,6 +53,26 @@ object FileStreamSink extends Logging { case _ => false } } + + /** + * Returns true if the path is the metadata dir or its ancestor is the metadata dir. + * E.g.: + * - ancestorIsMetadataDirectory(/.../_spark_metadata) => true + * - ancestorIsMetadataDirectory(/.../_spark_metadata/0) => true + * - ancestorIsMetadataDirectory(/a/b/c) => false + */ + def ancestorIsMetadataDirectory(path: Path, hadoopConf: Configuration): Boolean = { + val fs = path.getFileSystem(hadoopConf) + var currentPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + while (currentPath != null) { + if (currentPath.getName == FileStreamSink.metadataDir) { + return true + } else { + currentPath = currentPath.getParent + } + } + return false + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8703fe96e587..fa3c69612704 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -395,7 +395,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val fileCatalog = new InMemoryFileIndex( sparkSession = spark, - rootPaths = Seq(new Path(tempDir)), + rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], partitionSchema = None) // This should not fail. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 1211242b9fbb..1a2d3a13f3a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming import java.util.Locale +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -145,6 +147,43 @@ class FileStreamSinkSuite extends StreamTest { } } + test("partitioned writing and batch reading with 'basePath'") { + withTempDir { outputDir => + withTempDir { checkpointDir => + val outputPath = outputDir.getAbsolutePath + val inputData = MemoryStream[Int] + val ds = inputData.toDS() + + var query: StreamingQuery = null + + try { + query = + ds.map(i => (i, -i, i * 1000)) + .toDF("id1", "id2", "value") + .writeStream + .partitionBy("id1", "id2") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .format("parquet") + .start(outputPath) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { + query.processAllAvailable() + } + + val readIn = spark.read.option("basePath", outputPath).parquet(s"$outputDir/*/*") + checkDatasetUnorderly( + readIn.as[(Int, Int, Int)], + (1000, 1, -1), (2000, 2, -2), (3000, 3, -3)) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + // This tests whether FileStreamSink works with aggregations. Specifically, it tests // whether the correct streaming QueryExecution (i.e. IncrementalExecution) is used to // to execute the trigger for writing data to file sink. See SPARK-18440 for more details. @@ -266,4 +305,22 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("FileStreamSink.ancestorIsMetadataDirectory()") { + val hadoopConf = spark.sparkContext.hadoopConfiguration + def assertAncestorIsMetadataDirectory(path: String): Unit = + assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + def assertAncestorIsNotMetadataDirectory(path: String): Unit = + assert(!FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c/") + + assertAncestorIsNotMetadataDirectory(s"/a/b/c") + assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra") + } } From 13eb37c860c8f672d0e9d9065d0333f981db71e3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 3 May 2017 13:08:25 -0700 Subject: [PATCH 416/512] [MINOR][SQL] Fix the test title from =!= to <=>, remove a duplicated test and add a test for =!= ## What changes were proposed in this pull request? 
This PR proposes three things as below: - This test looks not testing `<=>` and identical with the test above, `===`. So, it removes the test. ```diff - test("<=>") { - checkAnswer( - testData2.filter($"a" === 1), - testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) - - checkAnswer( - testData2.filter($"a" === $"b"), - testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) - } ``` - Replace the test title from `=!=` to `<=>`. It looks the test actually testing `<=>`. ```diff + private lazy val nullData = Seq( + (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b") + ... - test("=!=") { + test("<=>") { - val nullData = spark.createDataFrame(sparkContext.parallelize( - Row(1, 1) :: - Row(1, 2) :: - Row(1, null) :: - Row(null, null) :: Nil), - StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - checkAnswer( nullData.filter($"b" <=> 1), ... ``` - Add the tests for `=!=` which looks not existing. ```diff + test("=!=") { + checkAnswer( + nullData.filter($"b" =!= 1), + Row(1, 2) :: Nil) + + checkAnswer(nullData.filter($"b" =!= null), Nil) + + checkAnswer( + nullData.filter($"a" =!= $"b"), + Row(1, 2) :: Nil) + } ``` ## How was this patch tested? Manually running the tests. Author: hyukjinkwon Closes #17842 from HyukjinKwon/minor-test-fix. --- .../spark/sql/ColumnExpressionSuite.scala | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b0f398dab745..bc708ca88d7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -39,6 +39,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) } + private lazy val nullData = Seq( + (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b") + test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -283,23 +286,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("<=>") { - checkAnswer( - testData2.filter($"a" === 1), - testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) - - checkAnswer( - testData2.filter($"a" === $"b"), - testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) - } - - test("=!=") { - val nullData = spark.createDataFrame(sparkContext.parallelize( - Row(1, 1) :: - Row(1, 2) :: - Row(1, null) :: - Row(null, null) :: Nil), - StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - checkAnswer( nullData.filter($"b" <=> 1), Row(1, 1) :: Nil) @@ -321,7 +307,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer( nullData2.filter($"a" <=> null), Row(null) :: Nil) + } + test("=!=") { + checkAnswer( + nullData.filter($"b" =!= 1), + Row(1, 2) :: Nil) + + checkAnswer(nullData.filter($"b" =!= null), Nil) + + checkAnswer( + nullData.filter($"a" =!= $"b"), + Row(1, 2) :: Nil) } test(">") { From 02bbe73118a39e2fb378aa2002449367a92f6d67 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 3 May 2017 19:15:28 -0700 Subject: [PATCH 417/512] [SPARK-20584][PYSPARK][SQL] Python generic hint support ## What changes were proposed in this pull request? Adds `hint` method to PySpark `DataFrame`. ## How was this patch tested? 
Unit tests, doctests. Author: zero323 Closes #17850 from zero323/SPARK-20584. --- python/pyspark/sql/dataframe.py | 29 +++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 16 ++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ab6d35bfa7c5..7b67985f2b32 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -380,6 +380,35 @@ def withWatermark(self, eventTime, delayThreshold): jdf = self._jdf.withWatermark(eventTime, delayThreshold) return DataFrame(jdf, self.sql_ctx) + @since(2.2) + def hint(self, name, *parameters): + """Specifies some hint on the current DataFrame. + + :param name: A name of the hint. + :param parameters: Optional parameters. + :return: :class:`DataFrame` + + >>> df.join(df2.hint("broadcast"), "name").show() + +----+---+------+ + |name|age|height| + +----+---+------+ + | Bob| 5| 85| + +----+---+------+ + """ + if len(parameters) == 1 and isinstance(parameters[0], list): + parameters = parameters[0] + + if not isinstance(name, str): + raise TypeError("name should be provided as str, got {0}".format(type(name))) + + for p in parameters: + if not isinstance(p, str): + raise TypeError( + "all parameters should be str, got {0} of type {1}".format(p, type(p))) + + jdf = self._jdf.hint(name, self._jseq(parameters)) + return DataFrame(jdf, self.sql_ctx) + @since(1.3) def count(self): """Returns the number of rows in this :class:`DataFrame`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ce4abf8fb7e5..f644624f7f31 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1906,6 +1906,22 @@ def test_functions_broadcast(self): # planner should not crash without a join broadcast(df1)._jdf.queryExecution().executedPlan() + def test_generic_hints(self): + from pyspark.sql import DataFrame + + df1 = self.spark.range(10e10).toDF("id") + df2 = self.spark.range(10e10).toDF("id") + + self.assertIsInstance(df1.hint("broadcast"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", []), DataFrame) + + # Dummy rules + self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame) + + plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) + def test_toDF_with_schema_string(self): data = [Row(key=i, value=str(i)) for i in range(100)] rdd = self.sc.parallelize(data, 5) From fc472bddd1d9c6a28e57e31496c0166777af597e Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 3 May 2017 21:40:18 -0700 Subject: [PATCH 418/512] [SPARK-20543][SPARKR] skip tests when running on CRAN ## What changes were proposed in this pull request? General rule on skip or not: skip if - RDD tests - tests could run long or complicated (streaming, hivecontext) - tests on error conditions - tests won't likely change/break ## How was this patch tested? unit tests, `R CMD check --as-cran`, `R CMD check` Author: Felix Cheung Closes #17817 from felixcheung/rskiptest. 
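A sketch of the guard pattern added throughout the suites (the test body is illustrative only and assumes an already initialized SparkR context `sc`; it is not copied from the suite):

```r
# Sketch: skip_on_cran() is a no-op when run locally or on Jenkins, but skips the
# test during R CMD check on CRAN, keeping RDD-level tests out of CRAN runs.
library(testthat)

test_that("RDD operations run outside CRAN", {
  skip_on_cran()

  rdd <- parallelize(sc, 1:10, 2L)
  expect_equal(countRDD(rdd), 10)
})
```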
--- R/pkg/inst/tests/testthat/test_Serde.R | 6 + R/pkg/inst/tests/testthat/test_Windows.R | 2 + R/pkg/inst/tests/testthat/test_binaryFile.R | 8 ++ .../tests/testthat/test_binary_function.R | 6 + R/pkg/inst/tests/testthat/test_broadcast.R | 4 + R/pkg/inst/tests/testthat/test_client.R | 8 ++ R/pkg/inst/tests/testthat/test_context.R | 16 +++ .../inst/tests/testthat/test_includePackage.R | 4 + .../tests/testthat/test_mllib_clustering.R | 4 + .../tests/testthat/test_mllib_regression.R | 12 ++ .../tests/testthat/test_parallelize_collect.R | 8 ++ R/pkg/inst/tests/testthat/test_rdd.R | 106 +++++++++++++++++- R/pkg/inst/tests/testthat/test_shuffle.R | 24 ++++ R/pkg/inst/tests/testthat/test_sparkR.R | 2 + R/pkg/inst/tests/testthat/test_sparkSQL.R | 61 +++++++++- R/pkg/inst/tests/testthat/test_streaming.R | 12 ++ R/pkg/inst/tests/testthat/test_take.R | 2 + R/pkg/inst/tests/testthat/test_textFile.R | 18 +++ R/pkg/inst/tests/testthat/test_utils.R | 6 + R/run-tests.sh | 2 +- 20 files changed, 307 insertions(+), 4 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index b5f6f1b54fa8..518fb7bd9404 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -20,6 +20,8 @@ context("SerDe functionality") sparkSession <- sparkR.session(enableHiveSupport = FALSE) test_that("SerDe of primitive types", { + skip_on_cran() + x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") @@ -38,6 +40,8 @@ test_that("SerDe of primitive types", { }) test_that("SerDe of list of primitive types", { + skip_on_cran() + x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) @@ -65,6 +69,8 @@ test_that("SerDe of list of primitive types", { }) test_that("SerDe of list of lists", { + skip_on_cran() + x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE), list("a", "b", "c")) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index 1d777ddb286d..919b063bf069 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -17,6 +17,8 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { + skip_on_cran() + if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index b5c279e3156e..63f54e1af02b 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -24,6 +24,8 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -38,6 +40,8 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) @@ -50,6 +54,8 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + skip_on_cran() + fileName1 <- 
tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -74,6 +80,8 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 59cb2e620440..25bb2b84266d 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -29,6 +29,8 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { + skip_on_cran() + actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -51,6 +53,8 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -69,6 +73,8 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 65f204d096f4..504ded4fc862 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -26,6 +26,8 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcast(sc, randomMat) @@ -38,6 +40,8 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/inst/tests/testthat/test_client.R index 0cf25fe1dbf3..3d53bebab630 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -18,6 +18,8 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -26,16 +28,22 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { + skip_on_cran() + expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index c64fe6edcd49..632a90d68177 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ 
b/R/pkg/inst/tests/testthat/test_context.R @@ -18,6 +18,8 @@ context("test functions in sparkR.R") test_that("Check masked functions", { + skip_on_cran() + # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. @@ -55,6 +57,8 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { + skip_on_cran() + for (i in 1:4) { sc <- suppressWarnings(sparkR.init()) rdd <- parallelize(sc, 1:20, 2L) @@ -73,6 +77,8 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { + skip_on_cran() + sc <- sparkR.sparkContext() # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 @@ -96,6 +102,8 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { + skip_on_cran() + sc <- sparkR.sparkContext() setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") @@ -108,12 +116,16 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { + skip_on_cran() + sparkR.sparkContext() setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { + skip_on_cran() + e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -141,6 +153,8 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { + skip_on_cran() + expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -168,6 +182,8 @@ test_that("spark.lapply should perform simple transforms", { }) test_that("add and get file to be downloaded with Spark job on every node", { + skip_on_cran() + sparkR.sparkContext() # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 563ea298c2dd..f823ad8e9c98 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -26,6 +26,8 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -42,6 +44,8 @@ test_that("include inside function", { }) test_that("use include package", { + skip_on_cran() + # Only run the test if plyr is installed. 
if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 1661e987b730..478012e8828c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -255,6 +255,8 @@ test_that("spark.lda with libsvm", { }) test_that("spark.lda with text input", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, optimizer = "online", features = "value") @@ -297,6 +299,8 @@ test_that("spark.lda with text input", { }) test_that("spark.posterior and spark.perplexity", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, features = "value", k = 3) diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 3e9ad7719807..58924f952c6b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -23,6 +23,8 @@ context("MLlib regression algorithms, except for tree-based algorithms") sparkSession <- sparkR.session(enableHiveSupport = FALSE) test_that("formula of spark.glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm @@ -195,6 +197,8 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -222,6 +226,8 @@ test_that("spark.glm save/load", { }) test_that("formula of glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . 
- Species + 0, data = training) @@ -248,6 +254,8 @@ test_that("formula of glm", { }) test_that("glm and predict", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) @@ -292,6 +300,8 @@ test_that("glm and predict", { }) test_that("glm summary", { + skip_on_cran() + # gaussian family training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -341,6 +351,8 @@ test_that("glm summary", { }) test_that("glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 55972e1ba469..1f7f387de08c 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -39,6 +39,8 @@ jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -66,6 +68,8 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -86,6 +90,8 @@ test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { + skip_on_cran() + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -95,6 +101,8 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + skip_on_cran() + # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index b72c801dd958..a3b1631e1d11 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -29,22 +29,30 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { + skip_on_cran() + expect_equal(getNumPartitionsRDD(rdd), 2) expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { + skip_on_cran() + expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - expect_equal(countRDD(rdd), 10) - expect_equal(lengthRDD(rdd), 10) + skip_on_cran() + + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { + skip_on_cran() + mods <- lapply(rdd, function(x) { x %% 3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -56,30 +64,40 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { + skip_on_cran() + multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) 
expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { + skip_on_cran() + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { + skip_on_cran() + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { + skip_on_cran() + flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { + skip_on_cran() + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -95,6 +113,8 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { + skip_on_cran() + vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -103,6 +123,8 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + skip_on_cran() + rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -117,6 +139,8 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + skip_on_cran() + # RDD rdd2 <- rdd # PipelinedRDD @@ -158,6 +182,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { + skip_on_cran() + sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -167,6 +193,8 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { + skip_on_cran() + fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -175,6 +203,8 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { + skip_on_cran() + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -191,10 +221,14 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { + skip_on_cran() + expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { + skip_on_cran() + # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -237,6 +271,8 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { + skip_on_cran() + multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -246,6 +282,8 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { + skip_on_cran() + l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -258,6 +296,8 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { + skip_on_cran() + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -271,6 +311,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { + skip_on_cran() + nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, 
nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -279,21 +321,29 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { + skip_on_cran() + max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { + skip_on_cran() + min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { + skip_on_cran() + sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { + skip_on_cran() + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -301,6 +351,8 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition @@ -322,6 +374,8 @@ test_that("repartition/coalesce on RDDs", { }) test_that("sortBy() on RDDs", { + skip_on_cran() + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) @@ -333,6 +387,8 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -345,6 +401,8 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -357,6 +415,8 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { + skip_on_cran() + actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -366,6 +426,8 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -379,6 +441,8 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 4), @@ -393,6 +457,8 @@ test_that("zipWithUniqueId() on RDDs", { }) test_that("zipWithIndex() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -407,24 +473,32 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { + skip_on_cran() + rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { + skip_on_cran() + keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { + skip_on_cran() + values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { + skip_on_cran() + actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -442,6 +516,8 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -471,6 +547,8 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -514,6 
+592,8 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { + skip_on_cran() + l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -541,6 +621,8 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { + skip_on_cran() + l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -570,6 +652,8 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { + skip_on_cran() + # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -586,6 +670,8 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -610,6 +696,8 @@ test_that("join() on pairwise RDDs", { }) test_that("leftOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -640,6 +728,8 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -667,6 +757,8 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -698,6 +790,8 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { + skip_on_cran() + numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -747,6 +841,8 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() on a pairwise RDD", { + skip_on_cran() + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -765,11 +861,15 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { + skip_on_cran() + rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -794,6 +894,8 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { + skip_on_cran() + rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index d38efab0fd1d..cedf4f100c6c 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -37,6 +37,8 @@ strList <- list("Dexter Morgan: Blood. 
Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { + skip_on_cran() + grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -46,6 +48,8 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { + skip_on_cran() + grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -55,6 +59,8 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { + skip_on_cran() + reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -64,6 +70,8 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { + skip_on_cran() + reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -72,6 +80,8 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { + skip_on_cran() + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -81,6 +91,8 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { + skip_on_cran() + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -89,6 +101,8 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { + skip_on_cran() + stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -101,6 +115,8 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { + skip_on_cran() + # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -129,6 +145,8 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { + skip_on_cran() + # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -172,6 +190,8 @@ test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { + skip_on_cran() + # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } @@ -187,6 +207,8 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { + skip_on_cran() + kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -205,6 +227,8 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { + skip_on_cran() + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R index f73fc6baecce..a40981c188f7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkR.R +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -18,6 +18,8 @@ context("functions in sparkR.R") test_that("sparkCheckInstall", { + skip_on_cran() + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, # and the SparkR job was submitted by "spark-submit" sparkHome <- paste0(tempdir(), "/", "sparkHome") diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 12867c15d1f9..a7bb3265d92d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -97,15 +97,21 @@ mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) test_that("calling sparkRSQL.init returns existing SQL context", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) 
expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { + skip_on_cran() + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { + skip_on_cran() + expect_equal(sparkR.session(), sparkSession) }) @@ -203,6 +209,8 @@ test_that("structField type strings", { }) test_that("create DataFrame from RDD", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -300,6 +308,8 @@ test_that("create DataFrame from RDD", { }) test_that("createDataFrame uses files for large objects", { + skip_on_cran() + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") @@ -360,6 +370,8 @@ test_that("read/write csv as DataFrame", { }) test_that("Support other types for options", { + skip_on_cran() + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -414,6 +426,8 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -525,6 +539,8 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { + skip_on_cran() + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -537,6 +553,8 @@ test_that("create DataFrame from a data.frame with complex types", { }) test_that("Collect DataFrame with complex types", { + skip_on_cran() + # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -624,6 +642,8 @@ test_that("read/write json files", { }) test_that("read/write json files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -637,6 +657,8 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -693,6 +715,8 @@ test_that( }) test_that("test cache, uncache and clearCache", { + skip_on_cran() + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") @@ -746,6 +770,8 @@ test_that("tableToDF() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -753,6 +779,8 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -763,6 +791,8 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { + skip_on_cran() + # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -792,6 +822,8 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) 
test_that("objectFile() works with row serialization", { + skip_on_cran() + objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) @@ -804,6 +836,8 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -872,6 +906,8 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { + skip_on_cran() + df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -1497,7 +1533,6 @@ test_that("column functions", { collect(select(df, alias(not(df$is_true), "is_false"))), data.frame(is_false = c(FALSE, TRUE, NA)) ) - }) test_that("column binary mathfunctions", { @@ -2306,6 +2341,8 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2327,6 +2364,8 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2373,6 +2412,8 @@ test_that("read/write Parquet files", { }) test_that("read/write Parquet files - compression option/mode", { + skip_on_cran() + df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -2390,6 +2431,8 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { + skip_on_cran() + # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -2411,6 +2454,8 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2644,6 +2689,8 @@ test_that("approxQuantile() on a DataFrame", { }) test_that("SQL error message is returned from JVM", { + skip_on_cran() + retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2652,6 +2699,8 @@ test_that("SQL error message is returned from JVM", { irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { + skip_on_cran() + expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -3069,6 +3118,8 @@ test_that("Window functions on a DataFrame", { }) test_that("createDataFrame sqlContext parameter backward compatibility", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -3148,6 +3199,8 @@ test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiW }) test_that("enableHiveSupport on SparkSession", { + skip_on_cran() + setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -3163,6 +3216,8 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { + skip_on_cran() + df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. 
# It makes sure that we can omit path argument in write.df API and then it calls @@ -3189,6 +3244,8 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { + skip_on_cran() + # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. @@ -3313,6 +3370,8 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + skip_on_cran() + # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. filesAfter <- list.files(path = sparkRDir, all.files = TRUE) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index b125cb0591de..884399102430 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -47,6 +47,8 @@ schema <- structType(structField("name", "string"), structField("count", "double")) test_that("read.stream, write.stream, awaitTermination, stopQuery", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -65,6 +67,8 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { }) test_that("print from explain, lastProgress, status, isActive", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) @@ -83,6 +87,8 @@ test_that("print from explain, lastProgress, status, isActive", { }) test_that("Stream other format", { + skip_on_cran() + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") df <- read.df(jsonPath, "json", schema) write.df(df, parquetPath, "parquet", "overwrite") @@ -108,6 +114,8 @@ test_that("Stream other format", { }) test_that("Non-streaming DataFrame", { + skip_on_cran() + c <- as.DataFrame(cars) expect_false(isStreaming(c)) @@ -117,6 +125,8 @@ test_that("Non-streaming DataFrame", { }) test_that("Unsupported operation", { + skip_on_cran() + # memory sink without aggregation df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), @@ -125,6 +135,8 @@ test_that("Unsupported operation", { }) test_that("Terminated by error", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) counts <- count(group_by(df, "name")) # This would not fail before returning with a StreamingQuery, diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index aaa532856c3d..e2130eaac78d 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -34,6 +34,8 @@ sparkSession <- sparkR.session(enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { + skip_on_cran() + numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, 
n = 1))) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 3b466066e939..28b7e8e3183f 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -24,6 +24,8 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -36,6 +38,8 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -46,6 +50,8 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -64,6 +70,8 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -78,6 +86,8 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -92,6 +102,8 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -103,6 +115,8 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -128,6 +142,8 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -141,6 +157,8 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 1ca383da26ec..4a01e875405f 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -23,6 +23,7 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { + skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. Instead, we rely on collectRDD() returning a # JList. 
@@ -40,6 +41,7 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { + skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -167,6 +169,8 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { + skip_on_cran() + method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, "col", "unknown", TRUE), @@ -177,6 +181,8 @@ test_that("captureJVMException", { }) test_that("hashCode", { + skip_on_cran() + expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) diff --git a/R/run-tests.sh b/R/run-tests.sh index 742a2c5ed76d..29764f48bd15 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" From b8302ccd02265f9d7a7895c7b033441fa2d8ffd1 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 4 May 2017 00:27:10 -0700 Subject: [PATCH 419/512] [SPARK-20015][SPARKR][SS][DOC][EXAMPLE] Document R Structured Streaming (experimental) in R vignettes and R & SS programming guide, R example ## What changes were proposed in this pull request? Add - R vignettes - R programming guide - SS programming guide - R example Also disable spark.als in vignettes for now since it's failing (SPARK-20402) ## How was this patch tested? manually Author: Felix Cheung Closes #17814 from felixcheung/rdocss. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 79 ++++- docs/sparkr.md | 4 + .../structured-streaming-programming-guide.md | 285 +++++++++++++++--- .../streaming/structured_network_wordcount.R | 57 ++++ 4 files changed, 381 insertions(+), 44 deletions(-) create mode 100644 examples/src/main/r/streaming/structured_network_wordcount.R diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 4b9d6c380609..d38ec4f1b6f3 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -182,7 +182,7 @@ head(df) ``` ### Data Sources -SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. 
SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. @@ -232,7 +232,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite" ``` ### Hive Tables -You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). ```{r, eval=FALSE} sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -314,7 +314,7 @@ Use `cube` or `rollup` to compute subtotals across multiple dimensions. mean(cube(carsDF, "cyl", "gear", "am"), "mpg") ``` -generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}, while +generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}, while ```{r} mean(rollup(carsDF, "cyl", "gear", "am"), "mpg") @@ -672,6 +672,7 @@ head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. + ```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) @@ -902,7 +903,7 @@ perplexity There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. -```{r} +```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(ratings, c("user", "item", "rating")) @@ -910,7 +911,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati ``` Extract latent factors. 
-```{r} +```{r, eval=FALSE} stats <- summary(model) userFactors <- stats$userFactors itemFactors <- stats$itemFactors @@ -920,7 +921,7 @@ head(itemFactors) Make predictions. -```{r} +```{r, eval=FALSE} predicted <- predict(model, df) head(predicted) ``` @@ -1002,6 +1003,72 @@ unlink(modelPath) ``` +## Structured Streaming + +SparkR supports the Structured Streaming API (experimental). + +You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. + +### Simple Source and Sink + +Spark has a few built-in input sources. As an example, to test with a socket source reading text into words and displaying the computed word counts: + +```{r, eval=FALSE} +# Create DataFrame representing the stream of input lines from connection +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") +``` + +### Kafka Source + +It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming. + +```{r, eval=FALSE} +topic <- read.stream("kafka", + kafka.bootstrap.servers = "host1:port1,host2:port2", + subscribe = "topic1") +keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)") +``` + +### Operations and Sinks + +Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, to start the streaming computation, you will call the `write.stream` method setting a sink and `outputMode`. + +A streaming `SparkDataFrame` can be written for debugging to the console, to a temporary in-memory table, or for further processing in a fault-tolerant manner to a File Sink in different formats. + +```{r, eval=FALSE} +noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# Aggregate +aggDF <- count(groupBy(noAggDF, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +head(sql("select * from aggregates")) +``` + + ## Advanced Topics ### SparkR Object Classes diff --git a/docs/sparkr.md b/docs/sparkr.md index 6dbd02a48890..569b85e72c3c 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -593,6 +593,10 @@ The following example shows how to save/load a MLlib model by SparkR.

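As a complement to the socket word-count snippet in the Structured Streaming section above, here is a minimal sketch of how the returned query handle can be inspected and stopped from R. It is illustrative only and not part of this patch; it assumes `query` is the StreamingQuery handle returned by `write.stream` in that example and an active Spark session.

```
# Assumes `query` is the StreamingQuery handle from the socket example above;
# shown for illustration only.
isActive(query)                # TRUE while the streaming query keeps running
status(query)                  # prints the current status of the query
lastProgress(query)            # prints the most recent progress update
awaitTermination(query, 5000)  # wait up to 5000 ms; TRUE if the query terminated
stopQuery(query)               # stop the stream once done
```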
    (?i)secret|password
    Regex to decide which Spark configuration properties and environment variables in driver and
-    executor environments contain sensitive information. When this regex matches a property, its
-    value is redacted from the environment UI and various logs like YARN and event logs.
+    executor environments contain sensitive information. When this regex matches a property key or
+    value, the value is redacted from the environment UI and various logs like YARN and event logs.
    spark.worker.cleanup.appDataTtl7 * 24 * 3600 (7 days)604800 (7 days, 7 * 24 * 3600) The number of seconds to retain application work directories on each worker. This is a Time To Live and should depend on the amount of available disk space you have. Application logs and jars are From 1ee494d0868a85af3154996732817ed63679f382 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 30 Apr 2017 08:24:10 -0700 Subject: [PATCH 387/512] [SPARK-20492][SQL] Do not print empty parentheses for invalid primitive types in parser ## What changes were proposed in this pull request? Currently, when the type string is invalid, it looks printing empty parentheses. This PR proposes a small improvement in an error message by removing it in the parse as below: ```scala spark.range(1).select($"col".cast("aa")) ``` **Before** ``` org.apache.spark.sql.catalyst.parser.ParseException: DataType aa() is not supported.(line 1, pos 0) == SQL == aa ^^^ ``` **After** ``` org.apache.spark.sql.catalyst.parser.ParseException: DataType aa is not supported.(line 1, pos 0) == SQL == aa ^^^ ``` ## How was this patch tested? Unit tests in `DataTypeParserSuite`. Author: hyukjinkwon Closes #17784 from HyukjinKwon/SPARK-20492. --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../spark/sql/catalyst/parser/DataTypeParserSuite.scala | 7 ++++++- .../resources/sql-tests/results/json-functions.sql.out | 2 +- .../scala/org/apache/spark/sql/JsonFunctionsSuite.scala | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2cf06d15664d..a48a693a95c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1491,8 +1491,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case ("decimal", precision :: scale :: Nil) => DecimalType(precision.getText.toInt, scale.getText.toInt) case (dt, params) => - throw new ParseException( - s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 3964fa3924b2..449052336900 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -30,7 +30,7 @@ class DataTypeParserSuite extends SparkFunSuite { } } - def intercept(sql: String): Unit = + def intercept(sql: String): ParseException = intercept[ParseException](CatalystSqlParser.parseDataType(sql)) def unsupported(dataTypeString: String): Unit = { @@ -118,6 +118,11 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("struct") + test("Do not print empty parentheses for no params") { + assert(intercept("unkwon").getMessage.contains("unkwon is not supported")) + assert(intercept("unkwon(1,2,3)").getMessage.contains("unkwon(1,2,3) is not supported")) + } + // DataType parser accepts certain reserved keywords. 
checkDataType( "Struct", diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 315e1730ce7d..fedabaee2237 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -141,7 +141,7 @@ struct<> -- !query 13 output org.apache.spark.sql.AnalysisException -DataType invalidtype() is not supported.(line 1, pos 2) +DataType invalidtype is not supported.(line 1, pos 2) == SQL == a InvalidType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 8465e8d036a6..69a500c845a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -274,7 +274,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } - assert(errMsg2.getMessage.contains("DataType invalidtype() is not supported")) + assert(errMsg2.getMessage.contains("DataType invalidtype is not supported")) val errMsg3 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))") } From ae3df4e98f160f94d1e52c90363f26eb351d0153 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 30 Apr 2017 12:33:03 -0700 Subject: [PATCH 388/512] [SPARK-20535][SPARKR] R wrappers for explode_outer and posexplode_outer ## What changes were proposed in this pull request? Ad R wrappers for - `o.a.s.sql.functions.explode_outer` - `o.a.s.sql.functions.posexplode_outer` ## How was this patch tested? Additional unit tests, manual testing. Author: zero323 Closes #17809 from zero323/SPARK-20535. --- R/pkg/NAMESPACE | 2 + R/pkg/R/functions.R | 56 +++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 1 + 4 files changed, 67 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 280046165848..db8e06db18ed 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -234,6 +234,7 @@ exportMethods("%in%", "endsWith", "exp", "explode", + "explode_outer", "expm1", "expr", "factorial", @@ -296,6 +297,7 @@ exportMethods("%in%", "percent_rank", "pmod", "posexplode", + "posexplode_outer", "quarter", "rand", "randn", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 6b91fa5bde67..f4a34fbabe4d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3803,3 +3803,59 @@ setMethod("repeat_string", jc <- callJStatic("org.apache.spark.sql.functions", "repeat", x@jc, numToInt(n)) column(jc) }) + +#' explode_outer +#' +#' Creates a new row for each element in the given array or map column. +#' Unlike \code{explode}, if the array/map is \code{null} or empty +#' then \code{null} is produced. 
+#' +#' @param x Column to compute on +#' +#' @rdname explode_outer +#' @name explode_outer +#' @family collection_funcs +#' @aliases explode_outer,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") +#' )) +#' +#' head(select(df, df$id, explode_outer(split_string(df$text, ",")))) +#' } +#' @note explode_outer since 2.3.0 +setMethod("explode_outer", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode_outer", x@jc) + column(jc) + }) + +#' posexplode_outer +#' +#' Creates a new row for each element with position in the given array or map column. +#' Unlike \code{posexplode}, if the array/map is \code{null} or empty +#' then the row (\code{null}, \code{null}) is produced. +#' +#' @param x Column to compute on +#' +#' @rdname posexplode_outer +#' @name posexplode_outer +#' @family collection_funcs +#' @aliases posexplode_outer,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") +#' )) +#' +#' head(select(df, df$id, posexplode_outer(split_string(df$text, ",")))) +#' } +#' @note posexplode_outer since 2.3.0 +setMethod("posexplode_outer", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "posexplode_outer", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 749ee9b54cc8..e510ff9a2d80 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1016,6 +1016,10 @@ setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) +#' @rdname explode_outer +#' @export +setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) + #' @rdname expr #' @export setGeneric("expr", function(x) { standardGeneric("expr") }) @@ -1175,6 +1179,10 @@ setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @export setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) +#' @rdname posexplode_outer +#' @export +setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) + #' @rdname quarter #' @export setGeneric("quarter", function(x) { standardGeneric("quarter") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 1a3d6df437d7..1828cddffd27 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1347,6 +1347,7 @@ test_that("column functions", { c18 <- covar_pop(c, c1) + covar_pop("c", "c1") c19 <- spark_partition_id() + coalesce(c) + coalesce(c1, c2, c3) c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") + c21 <- posexplode_outer(c) + explode_outer(c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) From 6613046c8c2daaf46a8ec13dd0a016aad22af1a4 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Sun, 30 Apr 2017 21:42:05 -0700 Subject: [PATCH 389/512] [MINOR][DOCS][PYTHON] Adding missing boolean type for replacement value in fillna ## What changes were proposed in this pull request? Currently pyspark Dataframe.fillna API supports boolean type when we pass dict, but it is missing in documentation. ## How was this patch tested? 
>>> spark.createDataFrame([Row(a=True),Row(a=None)]).fillna({"a" : True}).show() +----+ | a| +----+ |true| |true| +----+ Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Srinivasa Reddy Vundela Closes #17688 from vundela/fillna_doc_fix. --- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/tests.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ff21bb5d2fb3..ab6d35bfa7c5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1247,7 +1247,7 @@ def fillna(self, value, subset=None): Value to replace null values with. If the value is a dict, then `subset` is ignored and `value` must be a mapping from column name (string) to replacement value. The replacement value must be - an int, long, float, or string. + an int, long, float, boolean, or string. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2b2444304e04..cd92148dfa5d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1711,6 +1711,10 @@ def test_fillna(self): self.assertEqual(row.age, None) self.assertEqual(row.height, None) + # fillna with dictionary for boolean types + row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() + self.assertEqual(row.a, True) + def test_bitwise_operations(self): from pyspark.sql import functions row = Row(a=170, b=75) From 80e9cf1b59ce7186a4506f83e50f4fc7759c938c Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 30 Apr 2017 22:07:12 -0700 Subject: [PATCH 390/512] [SPARK-20490][SPARKR] Add R wrappers for eqNullSafe and ! / not ## What changes were proposed in this pull request? - Add null-safe equality operator `%<=>%` (sames as `o.a.s.sql.Column.eqNullSafe`, `o.a.s.sql.Column.<=>`) - Add boolean negation operator `!` and function `not `. ## How was this patch tested? Existing unit tests, additional unit tests, `check-cran.sh`. Author: zero323 Closes #17783 from zero323/SPARK-20490. --- R/pkg/NAMESPACE | 4 +- R/pkg/R/column.R | 55 ++++++++++++++++++++++- R/pkg/R/functions.R | 31 +++++++++++++ R/pkg/R/generics.R | 8 ++++ R/pkg/inst/tests/testthat/test_context.R | 4 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 20 +++++++++ 6 files changed, 117 insertions(+), 5 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index db8e06db18ed..e8de34d9371a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -182,7 +182,8 @@ exportMethods("arrange", exportClasses("Column") -exportMethods("%in%", +exportMethods("%<=>%", + "%in%", "abs", "acos", "add_months", @@ -291,6 +292,7 @@ exportMethods("%in%", "nanvl", "negate", "next_day", + "not", "ntile", "otherwise", "over", diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 539d91b0f879..147ee4b6887b 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -67,8 +67,7 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or", #, "!" 
= "unary_$bang" - "^" = "pow" + "&" = "and", "|" = "or", "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "getField", "getItem", "contains") @@ -302,3 +301,55 @@ setMethod("otherwise", jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) + +#' \%<=>\% +#' +#' Equality test that is safe for null values. +#' +#' Can be used, unlike standard equality operator, to perform null-safe joins. +#' Equivalent to Scala \code{Column.<=>} and \code{Column.eqNullSafe}. +#' +#' @param x a Column +#' @param value a value to compare +#' @rdname eq_null_safe +#' @name %<=>% +#' @aliases %<=>%,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df1 <- createDataFrame(data.frame( +#' x = c(1, NA, 3, NA), y = c(2, 6, 3, NA) +#' )) +#' +#' head(select(df1, df1$x == df1$y, df1$x %<=>% df1$y)) +#' +#' df2 <- createDataFrame(data.frame(y = c(3, NA))) +#' count(join(df1, df2, df1$y == df2$y)) +#' +#' count(join(df1, df2, df1$y %<=>% df2$y)) +#' } +#' @note \%<=>\% since 2.3.0 +setMethod("%<=>%", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- if (class(value) == "Column") { value@jc } else { value } + jc <- callJMethod(x@jc, "eqNullSafe", value) + column(jc) + }) + +#' ! +#' +#' Inversion of boolean expression. +#' +#' @rdname not +#' @name not +#' @aliases !,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) +#' +#' head(select(df, !column("x") > 0)) +#' } +#' @note ! since 2.3.0 +setMethod("!", signature(x = "Column"), function(x) not(x)) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f4a34fbabe4d..f9687d680e7a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3859,3 +3859,34 @@ setMethod("posexplode_outer", jc <- callJStatic("org.apache.spark.sql.functions", "posexplode_outer", x@jc) column(jc) }) + +#' not +#' +#' Inversion of boolean expression. +#' +#' \code{not} and \code{!} cannot be applied directly to numerical column. +#' To achieve R-like truthiness column has to be casted to \code{BooleanType}. 
+#' +#' @param x Column to compute on +#' @rdname not +#' @name not +#' @aliases not,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' is_true = c(TRUE, FALSE, NA), +#' flag = c(1, 0, 1) +#' )) +#' +#' head(select(df, not(df$is_true))) +#' +#' # Explicit cast is required when working with numeric column +#' head(select(df, not(cast(df$flag, "boolean")))) +#' } +#' @note not since 2.3.0 +setMethod("not", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e510ff9a2d80..d4e4958dc078 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -856,6 +856,10 @@ setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @export setGeneric("over", function(x, window) { standardGeneric("over") }) +#' @rdname eq_null_safe +#' @export +setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) + ###################### WindowSpec Methods ########################## #' @rdname partitionBy @@ -1154,6 +1158,10 @@ setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @export setGeneric("negate", function(x) { standardGeneric("negate") }) +#' @rdname not +#' @export +setGeneric("not", function(x) { standardGeneric("not") }) + #' @rdname next_day #' @export setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index c84711349111..c64fe6edcd49 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -21,10 +21,10 @@ test_that("Check masked functions", { # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. 
- namesOfMaskedCompletely <- c("cov", "filter", "sample") + namesOfMaskedCompletely <- c("cov", "filter", "sample", "not") namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform", "drop", "window", "as.data.frame", "union") + "summary", "transform", "drop", "window", "as.data.frame", "union", "not") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { namesOfMasked <- c("endsWith", "startsWith", namesOfMasked) } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 1828cddffd27..08296354ca7e 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1323,6 +1323,8 @@ test_that("column operators", { c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) c5 <- c2 ^ c3 ^ c4 + c6 <- c2 %<=>% c3 + c7 <- !c6 }) test_that("column functions", { @@ -1348,6 +1350,7 @@ test_that("column functions", { c19 <- spark_partition_id() + coalesce(c) + coalesce(c1, c2, c3) c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) + c22 <- not(c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1488,6 +1491,13 @@ test_that("column functions", { lapply( list(list(x = 1, y = -1, z = -2), list(x = 2, y = 3, z = 5)), as.environment)) + + df <- as.DataFrame(data.frame(is_true = c(TRUE, FALSE, NA))) + expect_equal( + collect(select(df, alias(not(df$is_true), "is_false"))), + data.frame(is_false = c(FALSE, TRUE, NA)) + ) + }) test_that("column binary mathfunctions", { @@ -1973,6 +1983,16 @@ test_that("filter() on a DataFrame", { filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + # test suites for %<=>% + dfNa <- read.json(jsonPathNa) + expect_equal(count(filter(dfNa, dfNa$age %<=>% 60)), 1) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% 60))), 5 - 1) + expect_equal(count(filter(dfNa, dfNa$age %<=>% NULL)), 3) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% NULL))), 5 - 3) + # match NA from two columns + expect_equal(count(filter(dfNa, dfNa$age %<=>% dfNa$height)), 2) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% dfNa$height))), 5 - 2) + # Test stats::filter is working #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint }) From a355b667a3718d9c5d48a0781e836bf5418ab842 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 30 Apr 2017 23:23:49 -0700 Subject: [PATCH 391/512] [SPARK-20541][SPARKR][SS] support awaitTermination without timeout ## What changes were proposed in this pull request? Add without param for timeout - will need this to submit a job that runs until stopped Need this for 2.2 ## How was this patch tested? manually, unit test Author: Felix Cheung Closes #17815 from felixcheung/rssawaitinfinite. --- R/pkg/R/generics.R | 2 +- R/pkg/R/streaming.R | 14 ++++++++++---- R/pkg/inst/tests/testthat/test_streaming.R | 1 + 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index d4e4958dc078..ef36765a7a72 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1518,7 +1518,7 @@ setGeneric("write.ml", function(object, path, ...) 
{ standardGeneric("write.ml") #' @rdname awaitTermination #' @export -setGeneric("awaitTermination", function(x, timeout) { standardGeneric("awaitTermination") }) +setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive #' @export diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index e353d2dd07c3..8390bd5e6de7 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -169,8 +169,10 @@ setMethod("isActive", #' immediately. #' #' @param x a StreamingQuery. -#' @param timeout time to wait in milliseconds -#' @return TRUE if query has terminated within the timeout period. +#' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} +#' is called or an error has occured. +#' @return TRUE if query has terminated within the timeout period; nothing if timeout is not +#' specified. #' @rdname awaitTermination #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method @@ -182,8 +184,12 @@ setMethod("isActive", #' @note experimental setMethod("awaitTermination", signature(x = "StreamingQuery"), - function(x, timeout) { - handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + function(x, timeout = NULL) { + if (is.null(timeout)) { + invisible(handledCallJMethod(x@ssq, "awaitTermination")) + } else { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + } }) #' stopQuery diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 1f4054a84df5..b125cb0591de 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -61,6 +61,7 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { stopQuery(q) expect_true(awaitTermination(q, 1)) + expect_error(awaitTermination(q), NA) }) test_that("print from explain, lastProgress, status, isActive", { From f0169a1c6a1ac06045d57f8aaa2c841bb39e23ac Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 1 May 2017 09:43:32 -0700 Subject: [PATCH 392/512] [SPARK-20290][MINOR][PYTHON][SQL] Add PySpark wrapper for eqNullSafe ## What changes were proposed in this pull request? Adds Python bindings for `Column.eqNullSafe` ## How was this patch tested? Manual tests, existing unit tests, doc build. Author: zero323 Closes #17605 from zero323/SPARK-20290. --- python/pyspark/sql/column.py | 55 ++++++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 2 +- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index b8df37f25180..e753ed402cdd 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -171,6 +171,61 @@ def __init__(self, jc): __ge__ = _bin_op("geq") __gt__ = _bin_op("gt") + _eqNullSafe_doc = """ + Equality test that is safe for null values. + + :param other: a value or :class:`Column` + + >>> from pyspark.sql import Row + >>> df1 = spark.createDataFrame([ + ... Row(id=1, value='foo'), + ... Row(id=2, value=None) + ... ]) + >>> df1.select( + ... df1['value'] == 'foo', + ... df1['value'].eqNullSafe('foo'), + ... df1['value'].eqNullSafe(None) + ... ).show() + +-------------+---------------+----------------+ + |(value = foo)|(value <=> foo)|(value <=> NULL)| + +-------------+---------------+----------------+ + | true| true| false| + | null| false| true| + +-------------+---------------+----------------+ + >>> df2 = spark.createDataFrame([ + ... Row(value = 'bar'), + ... Row(value = None) + ... 
]) + >>> df1.join(df2, df1["value"] == df2["value"]).count() + 0 + >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count() + 1 + >>> df2 = spark.createDataFrame([ + ... Row(id=1, value=float('NaN')), + ... Row(id=2, value=42.0), + ... Row(id=3, value=None) + ... ]) + >>> df2.select( + ... df2['value'].eqNullSafe(None), + ... df2['value'].eqNullSafe(float('NaN')), + ... df2['value'].eqNullSafe(42.0) + ... ).show() + +----------------+---------------+----------------+ + |(value <=> NULL)|(value <=> NaN)|(value <=> 42.0)| + +----------------+---------------+----------------+ + | false| true| false| + | false| false| true| + | true| false| false| + +----------------+---------------+----------------+ + + .. note:: Unlike Pandas, PySpark doesn't consider NaN values to be NULL. + See the `NaN Semantics`_ for details. + .. _NaN Semantics: + https://spark.apache.org/docs/latest/sql-programming-guide.html#nan-semantics + .. versionadded:: 2.3.0 + """ + eqNullSafe = _bin_op("eqNullSafe", _eqNullSafe_doc) + # `and`, `or`, `not` cannot be overloaded in Python, # so use bitwise operators as boolean operators __and__ = _bin_op('and') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cd92148dfa5d..ce4abf8fb7e5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -982,7 +982,7 @@ def test_column_operators(self): cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in cbool)) css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ - cs.startswith('a'), cs.endswith('a') + cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs) self.assertTrue(all(isinstance(c, Column) for c in css)) self.assertTrue(isinstance(ci.cast(LongType()), Column)) self.assertRaisesRegexp(ValueError, From 6b44c4d63ab14162e338c5f1ac77333956870a90 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 1 May 2017 09:46:35 -0700 Subject: [PATCH 393/512] [SPARK-20534][SQL] Make outer generate exec return empty rows ## What changes were proposed in this pull request? Generate exec does not produce `null` values if the generator for the input row is empty and the generate operates in outer mode without join. This is caused by the fact that the `join=false` code path is different from the `join=true` code path, and that the `join=false` code path did deal with outer properly. This PR addresses this issue. ## How was this patch tested? Updated `outer*` tests in `GeneratorFunctionSuite`. Author: Herman van Hovell Closes #17810 from hvanhovell/SPARK-20534. 
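For reference, a minimal standalone sketch (not part of this patch; the object name and local session setup are illustrative assumptions) of the `outer` behavior the updated `GeneratorFunctionSuite` tests expect when the generator output is not joined back to the input row:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.explode_outer

// Sketch only: shows the row-per-empty-generator behavior described above.
object ExplodeOuterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("explode-outer-sketch")
      .getOrCreate()
    import spark.implicits._

    // Same shape of input as the updated test: one non-empty and one empty array.
    val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")

    // Only the generated column is selected, so this exercises the join = false path.
    df.select(explode_outer($"intList")).show()
    // +----+
    // | col|
    // +----+
    // |   1|
    // |   2|
    // |   3|
    // |null|
    // +----+

    spark.stop()
  }
}
```

With the fix, the empty `intList` for `a = 2` still contributes one output row of nulls, which is what the `Row(null)` entries added to the `outer*` test expectations below encode.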
--- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../plans/logical/basicLogicalOperators.scala | 2 +- .../spark/sql/execution/GenerateExec.scala | 33 ++++++++++--------- .../spark/sql/GeneratorFunctionSuite.scala | 12 +++---- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index dd768d18e858..f2b9764b0f08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -441,8 +441,7 @@ object ColumnPruning extends Rule[LogicalPlan] { g.copy(child = prunedChild(g.child, g.references)) // Turn off `join` for Generate if no column from it's child is used - case p @ Project(_, g: Generate) - if g.join && !g.outer && p.references.subsetOf(g.generatedSet) => + case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) // Eliminate unneeded attributes from right side of a Left Existence Join. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ad757ebba85..f663d7b8a8f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -83,7 +83,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param qualifier Qualifier for the attributes of generator(UDTF) * @param generatorOutput The output schema of the Generator. * @param child Children logical plan node diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f87d05884b27..1812a1152cb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow]) extends Iterator[InternalRow] { - lazy val results = func().toIterator + lazy val results: Iterator[InternalRow] = func().toIterator override def hasNext: Boolean = results.hasNext override def next(): InternalRow = results.next() } @@ -50,7 +50,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. 
* @param generatorOutput the qualified output attributes of the generator of this node, which * constructed in analysis phase, and we can not change it, as the * parent node bound with it already. @@ -78,15 +78,15 @@ case class GenerateExec( override def outputPartitioning: Partitioning = child.outputPartitioning - val boundGenerator = BindReferences.bindReference(generator, child.output) + val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition - val rows = if (join) { - child.execute().mapPartitionsInternal { iter => - val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithIndexInternal { (index, iter) => + val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val rows = if (join) { val joinedRow = new JoinedRow - iter.flatMap { row => // we should always set the left (child output) joinedRow.withLeft(row) @@ -101,18 +101,21 @@ case class GenerateExec( // keep it the same as Hive does joinedRow.withRight(row) } + } else { + iter.flatMap { row => + val outputRows = boundGenerator.eval(row) + if (outer && outputRows.isEmpty) { + Seq(generatorNullRow) + } else { + outputRows + } + } ++ LazyIterator(boundGenerator.terminate) } - } else { - child.execute().mapPartitionsInternal { iter => - iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate) - } - } - val numOutputRows = longMetric("numOutputRows") - rows.mapPartitionsWithIndexInternal { (index, iter) => + // Convert the rows to unsafe rows. val proj = UnsafeProjection.create(output, output) proj.initialize(index) - iter.map { r => + rows.map { r => numOutputRows += 1 proj(r) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index cef5bbf0e85a..b9871afd59e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -91,7 +91,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(explode_outer('intList)), - Row(1) :: Row(2) :: Row(3) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) } test("single posexplode") { @@ -105,7 +105,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(posexplode_outer('intList)), - Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil) + Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil) } test("explode and other columns") { @@ -161,7 +161,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('intList).as('int)).select('int), - Row(1) :: Row(2) :: Row(3) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) checkAnswer( df.select(explode('intList).as('int)).select(sum('int)), @@ -182,7 +182,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map)), - Row("a", "b") :: Row("c", "d") :: Nil) + Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil) } test("explode on map with aliases") { @@ -198,7 +198,7 @@ class 
GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), - Row("a", "b") :: Nil) + Row("a", "b") :: Row(null, null) :: Nil) } test("self join explode") { @@ -279,7 +279,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { ) checkAnswer( df2.selectExpr("inline_outer(col1)"), - Row(3, "4") :: Row(5, "6") :: Nil + Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil ) } From ab30590f448d05fc1864c54a59b6815bdeef8fc7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 1 May 2017 10:25:29 -0700 Subject: [PATCH 394/512] [SPARK-20517][UI] Fix broken history UI download link The download link in history server UI is concatenated with: ``` Download{{duration}} {{sparkUser}} {{lastUpdated}}DownloadDownload
    +# Structured Streaming + +SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) + # R Function Name Conflicts When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 5b18cf2f3c2e..53b3db21da76 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,13 +8,13 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. # Quick Example -Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. 
You can see the full code in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). +Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py)/[R]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming/structured_network_wordcount.R). And if you [download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
    @@ -63,6 +63,13 @@ spark = SparkSession \ .getOrCreate() {% endhighlight %} +
    +
    + +{% highlight r %} +sparkR.session(appName = "StructuredNetworkWordCount") +{% endhighlight %} +
    @@ -136,6 +143,22 @@ wordCounts = words.groupBy("word").count() This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as "word". Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream. + +
    + +{% highlight r %} +# Create DataFrame representing the stream of input lines from connection to localhost:9999 +lines <- read.stream("socket", host = "localhost", port = 9999) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(group_by(words, "word")) +{% endhighlight %} + +This `lines` SparkDataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have a SQL expression with two SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we name the new column as "word". Finally, we have defined the `wordCounts` SparkDataFrame by grouping by the unique values in the SparkDataFrame and counting them. Note that this is a streaming SparkDataFrame which represents the running word counts of the stream. +
    @@ -181,10 +204,20 @@ query = wordCounts \ query.awaitTermination() {% endhighlight %} + +
    + +{% highlight r %} +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") + +awaitTermination(query) +{% endhighlight %} +
    -After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `query.awaitTermination()` to prevent the process from exiting while the query is active. +After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `awaitTermination()` to prevent the process from exiting while the query is active. To actually execute this example code, you can either compile the code in your own [Spark application](quick-start.html#self-contained-applications), or simply @@ -211,6 +244,11 @@ $ ./bin/run-example org.apache.spark.examples.sql.streaming.JavaStructuredNetwor $ ./bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999 {% endhighlight %} +
    +{% highlight bash %} +$ ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 +{% endhighlight %} +
    Then, any lines typed in the terminal running the netcat server will be counted and printed on screen every second. It will look something like the following. @@ -325,6 +363,35 @@ Batch: 0 | spark| 1| +------+-----+ +------------------------------------------- +Batch: 1 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 2| +| spark| 1| +|hadoop| 1| ++------+-----+ +... +{% endhighlight %} + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING structured_network_wordcount.R + +$ ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 + +------------------------------------------- +Batch: 0 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 1| +| spark| 1| ++------+-----+ + ------------------------------------------- Batch: 1 ------------------------------------------- @@ -409,14 +476,14 @@ to track the read position in the stream. The engine uses checkpointing and writ # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` -([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) +([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession)/[R](api/R/sparkR.session.html) docs) to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the [DataFrame/Dataset Programming Guide](sql-programming-guide.html). ## Creating streaming DataFrames and streaming Datasets -Streaming DataFrames can be created through the `DataStreamReader` interface +Streaming DataFrames can be created through the `DataStreamReader` interface ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) -returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. +returned by `SparkSession.readStream()`. In [R](api/R/read.stream.html), with the `read.stream()` method. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. #### Input Sources In Spark 2.0, there are a few built-in sources. @@ -445,7 +512,8 @@ Here are the details of all the sources in Spark. path: path to the input directory, and common to all file formats.

    For file-format-specific options, see the related methods in DataStreamReader - (Scala/Java/Python). + (Scala/Java/Python/R). E.g. for "parquet" format options see DataStreamReader.parquet() Yes Supports glob paths, but does not support multiple comma-separated paths/globs. @@ -483,7 +551,7 @@ Here are some examples. {% highlight scala %} val spark: SparkSession = ... -// Read text from socket +// Read text from socket val socketDF = spark .readStream .format("socket") @@ -493,7 +561,7 @@ val socketDF = spark socketDF.isStreaming // Returns True for DataFrames that have streaming sources -socketDF.printSchema +socketDF.printSchema // Read all the csv files written atomically in a directory val userSchema = new StructType().add("name", "string").add("age", "integer") @@ -510,7 +578,7 @@ val csvDF = spark {% highlight java %} SparkSession spark = ... -// Read text from socket +// Read text from socket Dataset socketDF = spark .readStream() .format("socket") @@ -537,7 +605,7 @@ Dataset csvDF = spark {% highlight python %} spark = SparkSession. ... -# Read text from socket +# Read text from socket socketDF = spark \ .readStream \ .format("socket") \ @@ -547,7 +615,7 @@ socketDF = spark \ socketDF.isStreaming() # Returns True for DataFrames that have streaming sources -socketDF.printSchema() +socketDF.printSchema() # Read all the csv files written atomically in a directory userSchema = StructType().add("name", "string").add("age", "integer") @@ -558,6 +626,25 @@ csvDF = spark \ .csv("/path/to/directory") # Equivalent to format("csv").load("/path/to/directory") {% endhighlight %} +
    +
    + +{% highlight r %} +sparkR.session(...) + +# Read text from socket +socketDF <- read.stream("socket", host = hostname, port = port) + +isStreaming(socketDF) # Returns TRUE for SparkDataFrames that have streaming sources + +printSchema(socketDF) + +# Read all the csv files written atomically in a directory +schema <- structType(structField("name", "string"), + structField("age", "integer")) +csvDF <- read.stream("csv", path = "/path/to/directory", schema = schema, sep = ";") +{% endhighlight %} +
    @@ -638,12 +725,24 @@ ds.groupByKey((MapFunction) value -> value.getDeviceType(), df = ... # streaming DataFrame with IOT device data with schema { device: string, deviceType: string, signal: double, time: DateType } # Select the devices which have signal more than 10 -df.select("device").where("signal > 10") +df.select("device").where("signal > 10") # Running count of the number of updates for each device type df.groupBy("deviceType").count() {% endhighlight %} +
    + +{% highlight r %} +df <- ... # streaming DataFrame with IOT device data with schema { device: string, deviceType: string, signal: double, time: DateType } + +# Select the devices which have signal more than 10 +select(where(df, "signal > 10"), "device") + +# Running count of the number of updates for each device type +count(groupBy(df, "deviceType")) +{% endhighlight %} +
    ### Window Operations on Event Time @@ -840,7 +939,7 @@ Streaming DataFrames can be joined with static DataFrames to create new streamin {% highlight scala %} val staticDf = spark.read. ... -val streamingDf = spark.readStream. ... +val streamingDf = spark.readStream. ... streamingDf.join(staticDf, "type") // inner equi-join with a static DF streamingDf.join(staticDf, "type", "right_join") // right outer join with a static DF @@ -972,7 +1071,7 @@ Once you have defined the final result DataFrame/Dataset, all that is left is fo ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. -- *Details of the output sink:* Data format, location, etc. +- *Details of the output sink:* Data format, location, etc. - *Output mode:* Specify what gets written to the output sink. @@ -1077,7 +1176,7 @@ Here is the compatibility matrix. #### Output Sinks There are a few types of built-in output sinks. -- **File sink** - Stores the output to a directory. +- **File sink** - Stores the output to a directory. {% highlight scala %} writeStream @@ -1145,7 +1244,8 @@ Here are the details of all the sinks in Spark. · "s3a://a/b/c/dataset.txt"

    For file-format-specific options, see the related methods in DataFrameWriter - (Scala/Java/Python). + (Scala/Java/Python/R). E.g. for "parquet" format options see DataFrameWriter.parquet() Yes @@ -1208,7 +1308,7 @@ noAggDF .option("checkpointLocation", "path/to/checkpoint/dir") .option("path", "path/to/destination/dir") .start() - + // ========== DF with aggregation ========== val aggDF = df.groupBy("device").count() @@ -1219,7 +1319,7 @@ aggDF .format("console") .start() -// Have all the aggregates in an in-memory table +// Have all the aggregates in an in-memory table aggDF .writeStream .queryName("aggregates") // this query name will be the table name @@ -1250,7 +1350,7 @@ noAggDF .option("checkpointLocation", "path/to/checkpoint/dir") .option("path", "path/to/destination/dir") .start(); - + // ========== DF with aggregation ========== Dataset aggDF = df.groupBy("device").count(); @@ -1261,7 +1361,7 @@ aggDF .format("console") .start(); -// Have all the aggregates in an in-memory table +// Have all the aggregates in an in-memory table aggDF .writeStream() .queryName("aggregates") // this query name will be the table name @@ -1292,7 +1392,7 @@ noAggDF \ .option("checkpointLocation", "path/to/checkpoint/dir") \ .option("path", "path/to/destination/dir") \ .start() - + # ========== DF with aggregation ========== aggDF = df.groupBy("device").count() @@ -1314,6 +1414,35 @@ aggDF \ spark.sql("select * from aggregates").show() # interactively query in-memory table {% endhighlight %} + +
    + +{% highlight r %} +# ========== DF with no aggregations ========== +noAggDF <- select(where(deviceDataDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# ========== DF with aggregation ========== +aggDF <- count(groupBy(df, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +# Interactively query in-memory table +head(sql("select * from aggregates")) +{% endhighlight %} +
    @@ -1351,7 +1480,7 @@ query.name // get the name of the auto-generated or user-specified name query.explain() // print detailed explanations of the query -query.stop() // stop the query +query.stop() // stop the query query.awaitTermination() // block until query is terminated, with stop() or with error @@ -1403,7 +1532,7 @@ query.name() # get the name of the auto-generated or user-specified name query.explain() # print detailed explanations of the query -query.stop() # stop the query +query.stop() # stop the query query.awaitTermination() # block until query is terminated, with stop() or with error @@ -1415,6 +1544,24 @@ query.lastProgress() # the most recent progress update of this streaming quer {% endhighlight %} + +
    + +{% highlight r %} +query <- write.stream(df, "console") # get the query object + +queryName(query) # get the name of the auto-generated or user-specified name + +explain(query) # print detailed explanations of the query + +stopQuery(query) # stop the query + +awaitTermination(query) # block until query is terminated, with stop() or with error + +lastProgress(query) # the most recent progress update of this streaming query + +{% endhighlight %} +
    @@ -1461,6 +1608,12 @@ spark.streams().get(id) # get a query object by its unique id spark.streams().awaitAnyTermination() # block until any one of them terminates {% endhighlight %} + +
    +{% highlight bash %} +Not available in R. +{% endhighlight %} +
    @@ -1644,6 +1797,58 @@ Will print something like the following. ''' {% endhighlight %} + +
    + +{% highlight r %} +query <- ... # a StreamingQuery +lastProgress(query) + +''' +Will print something like the following. + +{ + "id" : "8c57e1ec-94b5-4c99-b100-f694162df0b9", + "runId" : "ae505c5a-a64e-4896-8c28-c7cbaf926f16", + "name" : null, + "timestamp" : "2017-04-26T08:27:28.835Z", + "numInputRows" : 0, + "inputRowsPerSecond" : 0.0, + "processedRowsPerSecond" : 0.0, + "durationMs" : { + "getOffset" : 0, + "triggerExecution" : 1 + }, + "stateOperators" : [ { + "numRowsTotal" : 4, + "numRowsUpdated" : 0 + } ], + "sources" : [ { + "description" : "TextSocketSource[host: localhost, port: 9999]", + "startOffset" : 1, + "endOffset" : 1, + "numInputRows" : 0, + "inputRowsPerSecond" : 0.0, + "processedRowsPerSecond" : 0.0 + } ], + "sink" : { + "description" : "org.apache.spark.sql.execution.streaming.ConsoleSink@76b37531" + } +} +''' + +status(query) +''' +Will print something like the following. + +{ + "message" : "Waiting for data to arrive", + "isDataAvailable" : false, + "isTriggerActive" : false +} +''' +{% endhighlight %} +
    @@ -1703,11 +1908,17 @@ spark.streams().addListener(new StreamingQueryListener() { Not available in Python. {% endhighlight %} + +
    +{% highlight bash %} +Not available in R. +{% endhighlight %} +
    ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    @@ -1745,20 +1956,18 @@ aggDF \ .start() {% endhighlight %} +
    +
    + +{% highlight r %} +write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "path/to/HDFS/dir") +{% endhighlight %} +
    # Where to go from here -- Examples: See and run the -[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming) +- Examples: See and run the +[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming)/[R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r/streaming) examples. - Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) - - - - - - - - - diff --git a/examples/src/main/r/streaming/structured_network_wordcount.R b/examples/src/main/r/streaming/structured_network_wordcount.R new file mode 100644 index 000000000000..cda18ebc072e --- /dev/null +++ b/examples/src/main/r/streaming/structured_network_wordcount.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Counts words in UTF8 encoded, '\n' delimited text received from the network. 
+ +# To run this on your local machine, you need to first run a Netcat server +# $ nc -lk 9999 +# and then run the example +# ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-Streaming-structured-network-wordcount-example") + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 2) { + print("Usage: structured_network_wordcount.R ") + print(" and describe the TCP server that Structured Streaming") + print("would connect to receive data.") + q("no") +} + +hostname <- args[[1]] +port <- as.integer(args[[2]]) + +# Create DataFrame representing the stream of input lines from connection to localhost:9999 +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") + +awaitTermination(query) + +sparkR.session.stop() From 9c36aa27919fb7625e388f5c3c90af62ef902b24 Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 4 May 2017 01:41:36 -0700 Subject: [PATCH 420/512] [SPARK-20585][SPARKR] R generic hint support ## What changes were proposed in this pull request? Adds support for generic hints on `SparkDataFrame` ## How was this patch tested? Unit tests, `check-cran.sh` Author: zero323 Closes #17851 from zero323/SPARK-20585. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 30 +++++++++++++++++++++++ R/pkg/R/generics.R | 4 +++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 12 +++++++++ 4 files changed, 47 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7ecd168137e8..daa168c87ecd 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -123,6 +123,7 @@ exportMethods("arrange", "group_by", "groupBy", "head", + "hint", "insertInto", "intersect", "isLocal", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 7e57ba6287bb..1c8869202f67 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3715,3 +3715,33 @@ setMethod("rollup", sgd <- callJMethod(x@sdf, "rollup", jcol) groupedData(sgd) }) + +#' hint +#' +#' Specifies execution plan hint and return a new SparkDataFrame. +#' +#' @param x a SparkDataFrame. +#' @param name a name of the hint. +#' @param ... optional parameters for the hint. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions +#' @aliases hint,SparkDataFrame,character-method +#' @rdname hint +#' @name hint +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl)) +#' } +#' @note hint since 2.2.0 +setMethod("hint", + signature(x = "SparkDataFrame", name = "character"), + function(x, name, ...) { + parameters <- list(...) + stopifnot(all(sapply(parameters, is.character))) + jdf <- callJMethod(x@sdf, "hint", name, parameters) + dataFrame(jdf) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e02d46426a5a..56ef1bee9353 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -576,6 +576,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @export setGeneric("groupBy", function(x, ...) 
{ standardGeneric("groupBy") }) +#' @rdname hint +#' @export +setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) + #' @rdname insertInto #' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index a7bb3265d92d..82007a534849 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2182,6 +2182,18 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { unlink(jsonPath2) unlink(jsonPath3) + + # Join with broadcast hint + df1 <- sql("SELECT * FROM range(10e10)") + df2 <- sql("SELECT * FROM range(10e10)") + + execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id))) + expect_false(any(grepl("BroadcastHashJoin", execution_plan))) + + execution_plan_hint <- capture.output( + explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) }) test_that("toJSON() on DataFrame", { From f21897fc157ce467f2b2edb5631b31787883accd Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 4 May 2017 01:51:37 -0700 Subject: [PATCH 421/512] [SPARK-20544][SPARKR] R wrapper for input_file_name ## What changes were proposed in this pull request? Adds wrapper for `o.a.s.sql.functions.input_file_name` ## How was this patch tested? Existing unit tests, additional unit tests, `check-cran.sh`. Author: zero323 Closes #17818 from zero323/SPARK-20544. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 21 +++++++++++++++++++++ R/pkg/R/generics.R | 6 ++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 5 +++++ 4 files changed, 33 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index daa168c87ecd..ba0fe7708bcc 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -258,6 +258,7 @@ exportMethods("%<=>%", "hypot", "ifelse", "initcap", + "input_file_name", "instr", "isNaN", "isNotNull", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3d47b09ce551..5f9d11475c94 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3975,3 +3975,24 @@ setMethod("grouping_id", jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols) column(jc) }) + +#' input_file_name +#' +#' Creates a string column with the input file name for a given row +#' +#' @rdname input_file_name +#' @name input_file_name +#' @family normal_funcs +#' @aliases input_file_name,missing-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' head(select(df, input_file_name())) +#' } +#' @note input_file_name since 2.3.0 +setMethod("input_file_name", signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "input_file_name") + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 56ef1bee9353..e835ef3e4f40 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1080,6 +1080,12 @@ setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @export setGeneric("initcap", function(x) { standardGeneric("initcap") }) +#' @param x empty. Should be used with no argument. 
+#' @rdname input_file_name +#' @export +setGeneric("input_file_name", + function(x = "missing") { standardGeneric("input_file_name") }) + #' @rdname instr #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 82007a534849..47cc34a6c5b7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1402,6 +1402,11 @@ test_that("column functions", { expect_equal(collect(df2)[[3, 1]], FALSE) expect_equal(collect(df2)[[3, 2]], TRUE) + # Test that input_file_name() + actual_names <- sort(collect(distinct(select(df, input_file_name())))) + expect_equal(length(actual_names), 1) + expect_equal(basename(actual_names[1, 1]), basename(jsonPath)) + df3 <- select(df, between(df$name, c("Apache", "Spark"))) expect_equal(collect(df3)[[1, 1]], TRUE) expect_equal(collect(df3)[[2, 1]], FALSE) From 57b64703e66ec8490d8d9dbf6beebc160a61ec29 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Thu, 4 May 2017 01:54:59 -0700 Subject: [PATCH 422/512] [SPARK-20571][SPARKR][SS] Flaky Structured Streaming tests ## What changes were proposed in this pull request? Make tests more reliable by having it till processed. Increasing timeout value might help but ultimately the flakiness from processing delay when Jenkins is hard to account for. This isn't an actual public API supported ## How was this patch tested? unit tests Author: Felix Cheung Closes #17857 from felixcheung/rsstestrelia. --- R/pkg/inst/tests/testthat/test_streaming.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 884399102430..91df7ac6f984 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -55,10 +55,12 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) writeLines(mockLinesNa, jsonPathNa) awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) stopQuery(q) @@ -75,6 +77,7 @@ test_that("print from explain, lastProgress, status, isActive", { q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) @@ -99,6 +102,7 @@ test_that("Stream other format", { q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) expect_equal(queryName(q), "people3") From c5dceb8c65545169bc96628140b5acdaa85dd226 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 4 May 2017 17:56:43 +0800 Subject: [PATCH 423/512] [SPARK-20047][FOLLOWUP][ML] Constrained Logistic Regression follow up ## What changes were proposed in this pull request? Address some minor comments for #17715: * Put bound-constrained optimization params under expertParams. * Update some docs. ## How was this patch tested? 
Existing tests. Author: Yanbo Liang Closes #17829 from yanboliang/spark-20047-followup. --- .../classification/LogisticRegression.scala | 54 ++++++++++++------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d7dde329ed00..42dc7fbebe4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -183,14 +183,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * The bound matrix must be compatible with the shape (1, number of features) for binomial * regression, or (number of classes, number of features) for multinomial regression. * Otherwise, it throws exception. + * Default is none. * - * @group param + * @group expertParam */ @Since("2.2.0") val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients", "The lower bounds on coefficients if fitting under bound constrained optimization.") - /** @group getParam */ + /** @group expertGetParam */ @Since("2.2.0") def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients) @@ -199,14 +200,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * The bound matrix must be compatible with the shape (1, number of features) for binomial * regression, or (number of classes, number of features) for multinomial regression. * Otherwise, it throws exception. + * Default is none. * - * @group param + * @group expertParam */ @Since("2.2.0") val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients", "The upper bounds on coefficients if fitting under bound constrained optimization.") - /** @group getParam */ + /** @group expertGetParam */ @Since("2.2.0") def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients) @@ -214,14 +216,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * The lower bounds on intercepts if fitting under bound constrained optimization. * The bounds vector size must be equal with 1 for binomial regression, or the number * of classes for multinomial regression. Otherwise, it throws exception. + * Default is none. * - * @group param + * @group expertParam */ @Since("2.2.0") val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts", "The lower bounds on intercepts if fitting under bound constrained optimization.") - /** @group getParam */ + /** @group expertGetParam */ @Since("2.2.0") def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts) @@ -229,14 +232,15 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * The upper bounds on intercepts if fitting under bound constrained optimization. * The bound vector size must be equal with 1 for binomial regression, or the number * of classes for multinomial regression. Otherwise, it throws exception. + * Default is none. 
* - * @group param + * @group expertParam */ @Since("2.2.0") val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts", "The upper bounds on intercepts if fitting under bound constrained optimization.") - /** @group getParam */ + /** @group expertGetParam */ @Since("2.2.0") def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts) @@ -256,7 +260,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } if (!$(fitIntercept)) { require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts), - "Pls don't set bounds on intercepts if fitting without intercept.") + "Please don't set bounds on intercepts if fitting without intercept.") } super.validateAndTransformSchema(schema, fitting, featuresDataType) } @@ -393,7 +397,7 @@ class LogisticRegression @Since("1.2.0") ( /** * Set the lower bounds on coefficients if fitting under bound constrained optimization. * - * @group setParam + * @group expertSetParam */ @Since("2.2.0") def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value) @@ -401,7 +405,7 @@ class LogisticRegression @Since("1.2.0") ( /** * Set the upper bounds on coefficients if fitting under bound constrained optimization. * - * @group setParam + * @group expertSetParam */ @Since("2.2.0") def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value) @@ -409,7 +413,7 @@ class LogisticRegression @Since("1.2.0") ( /** * Set the lower bounds on intercepts if fitting under bound constrained optimization. * - * @group setParam + * @group expertSetParam */ @Since("2.2.0") def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value) @@ -417,7 +421,7 @@ class LogisticRegression @Since("1.2.0") ( /** * Set the upper bounds on intercepts if fitting under bound constrained optimization. 
* - * @group setParam + * @group expertSetParam */ @Since("2.2.0") def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value) @@ -427,28 +431,40 @@ class LogisticRegression @Since("1.2.0") ( numFeatures: Int): Unit = { if (isSet(lowerBoundsOnCoefficients)) { require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets && - $(lowerBoundsOnCoefficients).numCols == numFeatures) + $(lowerBoundsOnCoefficients).numCols == numFeatures, + "The shape of LowerBoundsOnCoefficients must be compatible with (1, number of features) " + + "for binomial regression, or (number of classes, number of features) for multinomial " + + "regression, but found: " + + s"(${getLowerBoundsOnCoefficients.numRows}, ${getLowerBoundsOnCoefficients.numCols}).") } if (isSet(upperBoundsOnCoefficients)) { require($(upperBoundsOnCoefficients).numRows == numCoefficientSets && - $(upperBoundsOnCoefficients).numCols == numFeatures) + $(upperBoundsOnCoefficients).numCols == numFeatures, + "The shape of upperBoundsOnCoefficients must be compatible with (1, number of features) " + + "for binomial regression, or (number of classes, number of features) for multinomial " + + "regression, but found: " + + s"(${getUpperBoundsOnCoefficients.numRows}, ${getUpperBoundsOnCoefficients.numCols}).") } if (isSet(lowerBoundsOnIntercepts)) { - require($(lowerBoundsOnIntercepts).size == numCoefficientSets) + require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " + + "lowerBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.") } if (isSet(upperBoundsOnIntercepts)) { - require($(upperBoundsOnIntercepts).size == numCoefficientSets) + require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " + + "upperBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.") } if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray) - .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always " + + .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always be " + "less than or equal to upperBoundsOnCoefficients, but found: " + s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " + s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.") } if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) { require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray) - .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always " + + .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always be " + "less than or equal to upperBoundsOnIntercepts, but found: " + s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " + s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.") From bfc8c79c8dda7668cfded2a728424853a26da035 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 4 May 2017 21:04:15 +0800 Subject: [PATCH 424/512] [SPARK-20566][SQL] ColumnVector should support `appendFloats` for array ## What changes were proposed in this pull request? This PR aims to add a missing `appendFloats` API for array into **ColumnVector** class. 
For double type, there is `appendDoubles` for array [here](https://github.com/apache/spark/blob/master/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java#L818-L824). ## How was this patch tested? Pass the Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #17836 from dongjoon-hyun/SPARK-20566. --- .../execution/vectorized/ColumnVector.java | 8 + .../vectorized/ColumnarBatchSuite.scala | 256 ++++++++++++++++-- 2 files changed, 240 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index b105e60a2d34..ad267ab0c9c4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -801,6 +801,14 @@ public final int appendFloats(int count, float v) { return result; } + public final int appendFloats(int length, float[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putFloats(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + public final int appendDouble(double v) { reserve(elementsAppended + 1); putDouble(elementsAppended, v); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8184d7d909f4..e48e3f640290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -41,24 +41,49 @@ class ColumnarBatchSuite extends SparkFunSuite { val column = ColumnVector.allocate(1024, IntegerType, memMode) var idx = 0 assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNotNull() + reference += false + assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNotNulls(3) + (1 to 3).foreach(_ => reference += false) + assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNull() + reference += true + assert(column.anyNullsSet()) + assert(column.numNulls() == 1) + + column.appendNulls(3) + (1 to 3).foreach(_ => reference += true) + assert(column.anyNullsSet()) + assert(column.numNulls() == 4) + + idx = column.elementsAppended column.putNotNull(idx) reference += false idx += 1 - assert(column.anyNullsSet() == false) + assert(column.anyNullsSet()) + assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 - assert(column.anyNullsSet() == true) - assert(column.numNulls() == 1) + assert(column.anyNullsSet()) + assert(column.numNulls() == 5) column.putNulls(idx, 3) reference += true reference += true reference += true idx += 3 - assert(column.anyNullsSet() == true) + assert(column.anyNullsSet()) + assert(column.numNulls() == 8) column.putNotNulls(idx, 4) reference += false @@ -66,8 +91,8 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 - assert(column.anyNullsSet() == true) - assert(column.numNulls() == 4) + assert(column.anyNullsSet()) + assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => assert(v._1 == column.isNullAt(v._2)) @@ -85,9 +110,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = 
mutable.ArrayBuffer.empty[Byte] val column = ColumnVector.allocate(1024, ByteType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray + column.appendBytes(2, values, 0) + reference += 10.toByte + reference += 20.toByte + + column.appendBytes(3, values, 2) + reference += 30.toByte + reference += 40.toByte + reference += 50.toByte + + column.appendBytes(6, 60.toByte) + (1 to 6).foreach(_ => reference += 60.toByte) + + column.appendByte(70.toByte) + reference += 70.toByte + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray column.putBytes(idx, 2, values, 0) reference += 1 reference += 2 @@ -126,9 +168,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Short] val column = ColumnVector.allocate(1024, ShortType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toShort).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray + column.appendShorts(2, values, 0) + reference += 10.toShort + reference += 20.toShort + + column.appendShorts(3, values, 2) + reference += 30.toShort + reference += 40.toShort + reference += 50.toShort + + column.appendShorts(6, 60.toShort) + (1 to 6).foreach(_ => reference += 60.toShort) + + column.appendShort(70.toShort) + reference += 70.toShort + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toShort).toArray column.putShorts(idx, 2, values, 0) reference += 1 reference += 2 @@ -189,9 +248,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Int] val column = ColumnVector.allocate(1024, IntegerType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray + column.appendInts(2, values, 0) + reference += 10 + reference += 20 + + column.appendInts(3, values, 2) + reference += 30 + reference += 40 + reference += 50 + + column.appendInts(6, 60) + (1 to 6).foreach(_ => reference += 60) + + column.appendInt(70) + reference += 70 + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray column.putInts(idx, 2, values, 0) reference += 1 reference += 2 @@ -257,9 +333,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Long] val column = ColumnVector.allocate(1024, LongType, memMode) - var idx = 0 - val values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray + var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray + column.appendLongs(2, values, 0) + reference += 10L + reference += 20L + + column.appendLongs(3, values, 2) + reference += 30L + reference += 40L + reference += 50L + + column.appendLongs(6, 60L) + (1 to 6).foreach(_ => reference += 60L) + + column.appendLong(70L) + reference += 70L + + var idx = column.elementsAppended + + values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray column.putLongs(idx, 2, values, 0) reference += 1 reference += 2 @@ -320,6 +413,97 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Float APIs") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Float] + + val column = ColumnVector.allocate(1024, FloatType, memMode) + + var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray + 
column.appendFloats(2, values, 0) + reference += .1f + reference += .2f + + column.appendFloats(3, values, 2) + reference += .3f + reference += .4f + reference += .5f + + column.appendFloats(6, .6f) + (1 to 6).foreach(_ => reference += .6f) + + column.appendFloat(.7f) + reference += .7f + + var idx = column.elementsAppended + + values = (1.0f :: 2.0f :: 3.0f :: 4.0f :: 5.0f :: Nil).toArray + column.putFloats(idx, 2, values, 0) + reference += 1.0f + reference += 2.0f + idx += 2 + + column.putFloats(idx, 3, values, 2) + reference += 3.0f + reference += 4.0f + reference += 5.0f + idx += 3 + + val buffer = new Array[Byte](8) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET, 2.234f) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET + 4, 1.123f) + + if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { + // Ensure array contains Little Endian floats + val bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getFloat(0)) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET + 4, bb.getFloat(4)) + } + + column.putFloats(idx, 1, buffer, 4) + column.putFloats(idx + 1, 1, buffer, 0) + reference += 1.123f + reference += 2.234f + idx += 2 + + column.putFloats(idx, 2, buffer, 0) + reference += 2.234f + reference += 1.123f + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextFloat() + column.putFloat(idx, v) + reference += v + idx += 1 + } else { + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + val v = random.nextFloat() + column.putFloats(idx, n, v) + var i = 0 + while (i < n) { + reference += v + i += 1 + } + idx += n + } + } + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getFloat(v._2), "Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) + } + } + column.close + }} + } + test("Double APIs") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() @@ -327,9 +511,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Double] val column = ColumnVector.allocate(1024, DoubleType, memMode) - var idx = 0 - val values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray + var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray + column.appendDoubles(2, values, 0) + reference += .1 + reference += .2 + + column.appendDoubles(3, values, 2) + reference += .3 + reference += .4 + reference += .5 + + column.appendDoubles(6, .6) + (1 to 6).foreach(_ => reference += .6) + + column.appendDouble(.7) + reference += .7 + + var idx = column.elementsAppended + + values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray column.putDoubles(idx, 2, values, 0) reference += 1.0 reference += 2.0 @@ -346,8 +547,8 @@ class ColumnarBatchSuite extends SparkFunSuite { Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, 1.123) if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { - // Ensure array contains Liitle Endian doubles - var bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + // Ensure array contains Little Endian doubles + val bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getDouble(0)) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, bb.getDouble(8)) } @@ -400,40 +601,47 @@ class ColumnarBatchSuite 
extends SparkFunSuite { val column = ColumnVector.allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) - var idx = 0 + + val str = "string" + column.appendByteArray(str.getBytes(StandardCharsets.UTF_8), + 0, str.getBytes(StandardCharsets.UTF_8).length) + reference += str + assert(column.arrayData().elementsAppended == 6) + + var idx = column.elementsAppended val values = ("Hello" :: "abc" :: Nil).toArray column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), 0, values(0).getBytes(StandardCharsets.UTF_8).length) reference += values(0) idx += 1 - assert(column.arrayData().elementsAppended == 5) + assert(column.arrayData().elementsAppended == 11) column.putByteArray(idx, values(1).getBytes(StandardCharsets.UTF_8), 0, values(1).getBytes(StandardCharsets.UTF_8).length) reference += values(1) idx += 1 - assert(column.arrayData().elementsAppended == 8) + assert(column.arrayData().elementsAppended == 14) // Just put llo val offset = column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), 2, values(0).getBytes(StandardCharsets.UTF_8).length - 2) reference += "llo" idx += 1 - assert(column.arrayData().elementsAppended == 11) + assert(column.arrayData().elementsAppended == 17) // Put the same "ll" at offset. This should not allocate more memory in the column. column.putArray(idx, offset, 2) reference += "ll" idx += 1 - assert(column.arrayData().elementsAppended == 11) + assert(column.arrayData().elementsAppended == 17) // Put a long string val s = "abcdefghijklmnopqrstuvwxyz" column.putByteArray(idx, (s + s).getBytes(StandardCharsets.UTF_8)) reference += (s + s) idx += 1 - assert(column.arrayData().elementsAppended == 11 + (s + s).length) + assert(column.arrayData().elementsAppended == 17 + (s + s).length) reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) From 0d16faab90e4cd1f73c5b749dbda7bc2a400b26f Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Fri, 5 May 2017 10:23:58 +0800 Subject: [PATCH 425/512] [SPARK-20574][ML] Allow Bucketizer to handle non-Double numeric column ## What changes were proposed in this pull request? Bucketizer currently requires input column to be Double, but the logic should work on any numeric data types. Many practical problems have integer/float data types, and it could get very tedious to manually cast them into Double before calling bucketizer. This PR extends bucketizer to handle all numeric types. ## How was this patch tested? New test. Author: Wayne Zhang Closes #17840 from actuaryzhang/bucketizer. 
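For illustration, a minimal sketch of the usage this enables, mirroring the new test case below (the column name and split values are only examples, and a SparkSession `spark` is assumed to be in scope):

```scala
import org.apache.spark.ml.feature.Bucketizer

// toDF needs the session's implicits
import spark.implicits._

// An integer input column can now be bucketized directly; the transformer
// casts the column to Double internally instead of rejecting the schema.
val df = Seq(-2, -1, 0, 1, 2).toDF("feature")

val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("result")
  .setSplits(Array(-3.0, 0.0, 3.0))

// Values below 0.0 fall into bucket 0.0, the rest into bucket 1.0.
bucketizer.transform(df).show()
```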
--- .../apache/spark/ml/feature/Bucketizer.scala | 4 +-- .../spark/ml/feature/BucketizerSuite.scala | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index d1f3b2af1e48..bb8f2a3aa5f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) } - val newCol = bucketizer(filteredDataset($(inputCol))) + val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) val newField = prepOutputField(filteredDataset.schema) filteredDataset.withColumn($(outputCol), newCol, newField.metadata) } @@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(inputCol)) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index aac29137d791..420fb17ddce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -162,6 +164,29 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setSplits(Array(0.1, 0.8, 0.9)) testDefaultReadWrite(t) } + + test("Bucket numeric features") { + val splits = Array(-3.0, 0.0, 3.0) + val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0) + val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0) + val dataFrame: DataFrame = data.zip(expectedBuckets).toSeq.toDF("feature", "expected") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits) + + val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType, + ByteType, DecimalType(10, 0)) + for (mType <- types) { + val df = dataFrame.withColumn("feature", col("feature").cast(mType)) + bucketizer.transform(df).select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, "The result is not correct after bucketing in type " + + mType.toString + ". " + s"Expected $y but found $x.") + } + } + } } private object BucketizerSuite extends SparkFunSuite { From 4411ac70524ced901f7807d492fb0ad2480a8841 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 5 May 2017 09:50:40 +0100 Subject: [PATCH 426/512] [INFRA] Close stale PRs ## What changes were proposed in this pull request? This PR proposes to close a stale PR, several PRs suggested to be closed by a committer and obviously inappropriate PRs. 
Closes #11119 Closes #17853 Closes #17732 Closes #17456 Closes #17410 Closes #17314 Closes #17362 Closes #17542 ## How was this patch tested? N/A Author: hyukjinkwon Closes #17855 from HyukjinKwon/close-pr. From 37cdf077cd3f436f777562df311e3827b0727ce7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 5 May 2017 11:31:59 +0100 Subject: [PATCH 427/512] [SPARK-19660][SQL] Replace the deprecated property name fs.default.name to fs.defaultFS that newly introduced ## What changes were proposed in this pull request? Replace the deprecated property name `fs.default.name` to `fs.defaultFS` that newly introduced. ## How was this patch tested? Existing tests Author: Yuming Wang Closes #17856 from wangyum/SPARK-19660. --- .../spark/sql/execution/streaming/state/StateStoreSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index ebb7422765eb..cc09b2d5b776 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -314,7 +314,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) - conf.set("fs.default.name", "fake:///") + conf.set("fs.defaultFS", "fake:///") val provider = newStoreProvider(hadoopConf = conf) provider.getStore(0).commit() From 5773ab121d5d7cbefeef17ff4ac6f8af36cc1251 Mon Sep 17 00:00:00 2001 From: jyu00 Date: Fri, 5 May 2017 11:36:51 +0100 Subject: [PATCH 428/512] [SPARK-20546][DEPLOY] spark-class gets syntax error in posix mode ## What changes were proposed in this pull request? Updated spark-class to turn off posix mode so the process substitution doesn't cause a syntax error. ## How was this patch tested? Existing unit tests, manual spark-shell testing with posix mode on Author: jyu00 Closes #17852 from jyu00/master. --- bin/spark-class | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bin/spark-class b/bin/spark-class index 77ea40cc3794..65d3b9612909 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -72,6 +72,8 @@ build_command() { printf "%d\0" $? } +# Turn off posix mode since it does not allow process substitution +set +o posix CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") From 9064f1b04461513a147aeb8179471b05595ddbc4 Mon Sep 17 00:00:00 2001 From: madhu Date: Fri, 5 May 2017 22:44:03 +0800 Subject: [PATCH 429/512] [SPARK-20495][SQL][CORE] Add StorageLevel to cacheTable API ## What changes were proposed in this pull request? Currently cacheTable API only supports MEMORY_AND_DISK. This PR adds additional API to take different storage levels. ## How was this patch tested? unit tests Author: madhu Closes #17802 from phatak-dev/cacheTableAPI. 
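A quick sketch of the added overload (assuming a SparkSession `spark` is in scope; this mirrors the new unit test):

```scala
import org.apache.spark.storage.StorageLevel

spark.range(10).createOrReplaceTempView("my_temp_table")

// Cache at an explicit storage level instead of the default MEMORY_AND_DISK.
spark.catalog.cacheTable("my_temp_table", StorageLevel.DISK_ONLY)

// The chosen level is reflected on the cached relation.
assert(spark.table("my_temp_table").storageLevel == StorageLevel.DISK_ONLY)
```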
--- project/MimaExcludes.scala | 2 ++ .../org/apache/spark/sql/catalog/Catalog.scala | 14 +++++++++++++- .../apache/spark/sql/internal/CatalogImpl.scala | 13 +++++++++++++ .../apache/spark/sql/internal/CatalogSuite.scala | 8 ++++++++ 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dbf933f28a78..d50882cb1917 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,8 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-20495][SQL] Add StorageLevel to cacheTable API + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable") ) // Exclude rules for 2.2.x diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 7e5da012f84c..ab81725def3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} import org.apache.spark.sql.types.StructType - +import org.apache.spark.storage.StorageLevel /** * Catalog interface for Spark. To access this, use `SparkSession.catalog`. @@ -476,6 +476,18 @@ abstract class Catalog { */ def cacheTable(tableName: String): Unit + /** + * Caches the specified table with the given storage level. + * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. + * @param storageLevel storage level to cache table. + * @since 2.3.0 + */ + def cacheTable(tableName: String, storageLevel: StorageLevel): Unit + + /** * Removes the specified table from the in-memory cache. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 0b8e53868c99..e1049c665a41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel + /** @@ -419,6 +421,17 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { sparkSession.sharedState.cacheManager.cacheQuery(sparkSession.table(tableName), Some(tableName)) } + /** + * Caches the specified table or view with the given storage level. + * + * @group cachemgmt + * @since 2.3.0 + */ + override def cacheTable(tableName: String, storageLevel: StorageLevel): Unit = { + sparkSession.sharedState.cacheManager.cacheQuery( + sparkSession.table(tableName), Some(tableName), storageLevel) + } + /** * Removes the specified table or view from the in-memory cache. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 8f9c52cb1e03..bc641fd280a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel /** @@ -535,4 +536,11 @@ class CatalogSuite .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true) assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) } + + test("cacheTable with storage level") { + createTempTable("my_temp_table") + spark.catalog.cacheTable("my_temp_table", StorageLevel.DISK_ONLY) + assert(spark.table("my_temp_table").storageLevel == StorageLevel.DISK_ONLY) + } + } From b9ad2d1916af5091c8585d06ccad8219e437e2bc Mon Sep 17 00:00:00 2001 From: Jarrett Meyer Date: Fri, 5 May 2017 08:30:42 -0700 Subject: [PATCH 430/512] [SPARK-20613] Remove excess quotes in Windows executable ## What changes were proposed in this pull request? Quotes are already added to the RUNNER variable on line 54. There is no need to put quotes on line 67. If you do, you will get an error when launching Spark. '""C:\Program' is not recognized as an internal or external command, operable program or batch file. ## How was this patch tested? Tested manually on Windows 10. Author: Jarrett Meyer Closes #17861 from jarrettmeyer/fix-windows-cmd. --- bin/spark-class2.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 9faa7d65f83e..f6157f42843e 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -51,7 +51,7 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" ( - set RUNNER="%JAVA_HOME%\bin\java" + set RUNNER=%JAVA_HOME%\bin\java ) else ( where /q "%RUNNER%" if ERRORLEVEL 1 ( From 41439fd52dd263b9f7d92e608f027f193f461777 Mon Sep 17 00:00:00 2001 From: Yucai Date: Fri, 5 May 2017 09:51:57 -0700 Subject: [PATCH 431/512] [SPARK-20381][SQL] Add SQL metrics of numOutputRows for ObjectHashAggregateExec ## What changes were proposed in this pull request? ObjectHashAggregateExec is missing numOutputRows, add this metrics for it. ## How was this patch tested? Added unit tests for the new metrics. Author: Yucai Closes #17678 from yucai/objectAgg_numOutputRows. 
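For reference, a minimal sketch of a query that exercises this path (object-based aggregates such as `collect_set` are normally planned with ObjectHashAggregateExec, so the new metric has something to count; a SparkSession `spark` is assumed):

```scala
import org.apache.spark.sql.functions.collect_set

// collect_set keeps its buffer as a Scala object, so the query is typically
// planned with ObjectHashAggregateExec (partial and final) rather than
// HashAggregateExec.
val df = spark.range(100)
  .selectExpr("id % 10 AS a")
  .groupBy("a")
  .agg(collect_set("a"))

df.collect()  // run the query so the per-operator SQL metrics are populated

// "number of output rows" for both aggregate nodes then shows up in the
// SQL tab of the web UI for this query.
```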
--- .../aggregate/ObjectAggregationIterator.scala | 8 ++++++-- .../aggregate/ObjectHashAggregateExec.scala | 3 ++- .../sql/execution/metric/SQLMetricsSuite.scala | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 3a7fcf1fa9d8..6e47f9d61119 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering} import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator @@ -39,7 +40,8 @@ class ObjectAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, originalInputAttributes: Seq[Attribute], inputRows: Iterator[InternalRow], - fallbackCountThreshold: Int) + fallbackCountThreshold: Int, + numOutputRows: SQLMetric) extends AggregationIterator( groupingExpressions, originalInputAttributes, @@ -83,7 +85,9 @@ class ObjectAggregationIterator( override final def next(): UnsafeRow = { val entry = aggBufferIterator.next() - generateOutput(entry.groupingKey, entry.aggregationBuffer) + val res = generateOutput(entry.groupingKey, entry.aggregationBuffer) + numOutputRows += 1 + res } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 3fcb7ec9a641..b53521b1b6ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -117,7 +117,8 @@ case class ObjectHashAggregateExec( newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), child.output, iter, - fallbackCountThreshold) + fallbackCountThreshold, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2ce7db6a22c0..e544245588f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -143,6 +143,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("ObjectHashAggregate metrics") { + // Assume the execution plan is + // ... 
-> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> ObjectHashAggregate(nodeId = 0) + val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 2L)), + 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).agg(collect_set('a)) + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 4L)), + 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 3L))) + ) + } + test("Sort metrics") { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) From bd5788287957d8610a6d19c273b75bd4cdd2d166 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 5 May 2017 11:08:26 -0700 Subject: [PATCH 432/512] [SPARK-20603][SS][TEST] Set default number of topic partitions to 1 to reduce the load ## What changes were proposed in this pull request? I checked the logs of https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.2-test-maven-hadoop-2.7/47/ and found it took several seconds to create Kafka internal topic `__consumer_offsets`. As Kafka creates this topic lazily, the topic creation happens in the first test `deserialization of initial offset with Spark 2.1.0` and causes it timeout. This PR changes `offsets.topic.num.partitions` from the default value 50 to 1 to make creating `__consumer_offsets` (50 partitions -> 1 partition) much faster. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #17863 from zsxwing/fix-kafka-flaky-test. --- .../scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 2ce2760b7f46..f86b8f586d2a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -292,6 +292,7 @@ class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends L props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props.put("delete.topic.enable", "true") + props.put("offsets.topic.num.partitions", "1") props.putAll(withBrokerProps.asJava) props } From b31648c081e8db34e0d6c71875318f7b0b047c8b Mon Sep 17 00:00:00 2001 From: Jannik Arndt Date: Fri, 5 May 2017 11:42:55 -0700 Subject: [PATCH 433/512] [SPARK-20557][SQL] Support for db column type TIMESTAMP WITH TIME ZONE ## What changes were proposed in this pull request? SparkSQL can now read from a database table with column type [TIMESTAMP WITH TIME ZONE](https://docs.oracle.com/javase/8/docs/api/java/sql/Types.html#TIMESTAMP_WITH_TIMEZONE). ## How was this patch tested? Tested against Oracle database. JoshRosen, you seem to know the class, would you look at this? Thanks! Author: Jannik Arndt Closes #17832 from JannikArndt/spark-20557-timestamp-with-timezone. 
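Sketch of the kind of read this enables, mirroring the new integration test (`jdbcUrl` is a placeholder for the Oracle connection URL, and a SparkSession `spark` is assumed):

```scala
import java.util.Properties

// A table with a TIMESTAMP WITH TIME ZONE column can now be read; the column
// is mapped to Spark's TimestampType.
val df = spark.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties)
df.printSchema()  // the TIMESTAMP WITH TIME ZONE column surfaces as `timestamp`

// Values come back as java.sql.Timestamp, as asserted in the new test.
val row = df.collect().head
assert(row.get(1).isInstanceOf[java.sql.Timestamp])
```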
--- .../spark/sql/jdbc/OracleIntegrationSuite.scala | 13 +++++++++++++ .../sql/execution/datasources/jdbc/JdbcUtils.scala | 3 +++ 2 files changed, 16 insertions(+) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 1bb89a361ca7..85d4a4a791e6 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -70,6 +70,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")).executeUpdate() conn.commit() + conn.prepareStatement("CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)") + .executeUpdate() + conn.prepareStatement("INSERT INTO ts_with_timezone VALUES (1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))") + .executeUpdate() + conn.commit() + sql( s""" |CREATE TEMPORARY VIEW datetime @@ -185,4 +191,11 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo sql("INSERT INTO TABLE datetime1 SELECT * FROM datetime where id = 1") checkRow(sql("SELECT * FROM datetime1 where id = 1").head()) } + + test("SPARK-20557: column type TIMEZONE with TIME STAMP should be recognized") { + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + val rows = dfRead.collect() + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types(1).equals("class java.sql.Timestamp")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 0183805d5625..fb877d1ca763 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -223,6 +223,9 @@ object JdbcUtils extends Logging { case java.sql.Types.STRUCT => StringType case java.sql.Types.TIME => TimestampType case java.sql.Types.TIMESTAMP => TimestampType + case java.sql.Types.TIMESTAMP_WITH_TIMEZONE + => TimestampType + case -101 => TimestampType // Value for Timestamp with Time Zone in Oracle case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType From 5d75b14bf0f4c1f0813287efaabf49797908ed55 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 May 2017 15:31:06 -0700 Subject: [PATCH 434/512] [SPARK-20616] RuleExecutor logDebug of batch results should show diff to start of batch ## What changes were proposed in this pull request? Due to a likely typo, the logDebug msg printing the diff of query plans shows a diff to the initial plan, not diff to the start of batch. ## How was this patch tested? Now the debug message prints the diff between start and end of batch. Author: Juliusz Sompolski Closes #17875 from juliuszsompolski/SPARK-20616. 
--- .../org/apache/spark/sql/catalyst/rules/RuleExecutor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6fc828f63f15..85b368c86263 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -122,7 +122,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { logDebug( s""" |=== Result of Batch ${batch.name} === - |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} + |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { logTrace(s"Batch ${batch.name} has no effect.") From b433acae74887e59f2e237a6284a4ae04fbbe854 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 5 May 2017 21:26:55 -0700 Subject: [PATCH 435/512] [SPARK-20614][PROJECT INFRA] Use the same log4j configuration with Jenkins in AppVeyor ## What changes were proposed in this pull request? Currently, there are flooding logs in AppVeyor (in the console). This has been fine because we can download all the logs. However, (given my observations so far), logs are truncated when there are too many. It has been grown recently and it started to get truncated. For example, see https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark/build/1209-master Even after the log is downloaded, it looks truncated as below: ``` [00:44:21] 17/05/04 18:56:18 INFO TaskSetManager: Finished task 197.0 in stage 601.0 (TID 9211) in 0 ms on localhost (executor driver) (194/200) [00:44:21] 17/05/04 18:56:18 INFO Executor: Running task 199.0 in stage 601.0 (TID 9213) [00:44:21] 17/05/04 18:56:18 INFO Executor: Finished task 198.0 in stage 601.0 (TID 9212). 2473 bytes result sent to driver ... ``` Probably, it looks better to use the same log4j configuration that we are using for SparkR tests in Jenkins(please see https://github.com/apache/spark/blob/fc472bddd1d9c6a28e57e31496c0166777af597e/R/run-tests.sh#L26 and https://github.com/apache/spark/blob/fc472bddd1d9c6a28e57e31496c0166777af597e/R/log4j.properties) ``` # Set everything to be logged to the file target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true log4j.appender.file.file=R/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN org.eclipse.jetty.LEVEL=WARN ``` ## How was this patch tested? Manually tested with spark-test account - https://ci.appveyor.com/project/spark-test/spark/build/672-r-log4j (there is an example for flaky test here) - https://ci.appveyor.com/project/spark-test/spark/build/673-r-log4j (I re-ran the build). Author: hyukjinkwon Closes #17873 from HyukjinKwon/appveyor-reduce-logs. 
--- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index bbb27589cad0..4d31af70f056 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -49,7 +49,7 @@ build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package test_script: - - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R + - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R notifications: - provider: Email From cafca54c0ea8bd9c3b80dcbc88d9f2b8d708a026 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Sat, 6 May 2017 22:21:19 -0700 Subject: [PATCH 436/512] [SPARK-20557][SQL] Support JDBC data type Time with Time Zone ### What changes were proposed in this pull request? This PR is to support JDBC data type TIME WITH TIME ZONE. It can be converted to TIMESTAMP In addition, before this PR, for unsupported data types, we simply output the type number instead of the type name. ``` java.sql.SQLException: Unsupported type 2014 ``` After this PR, the message is like ``` java.sql.SQLException: Unsupported type TIMESTAMP_WITH_TIMEZONE ``` - Also upgrade the H2 version to `1.4.195` which has the type fix for "TIMESTAMP WITH TIMEZONE". However, it is not fully supported. Thus, we capture the exception, but we still need it to partially test the support of "TIMESTAMP WITH TIMEZONE", because Docker tests are not regularly run. ### How was this patch tested? Added test cases. Author: Xiao Li Closes #17835 from gatorsmile/h2. --- .../sql/jdbc/OracleIntegrationSuite.scala | 2 +- .../sql/jdbc/PostgresIntegrationSuite.scala | 15 ++++++++++++ sql/core/pom.xml | 2 +- .../datasources/jdbc/JdbcUtils.scala | 12 +++++++--- .../spark/sql/internal/CatalogImpl.scala | 1 - .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 24 +++++++++++++++++-- 6 files changed, 48 insertions(+), 8 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 85d4a4a791e6..f7b1ec34ced7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -192,7 +192,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo checkRow(sql("SELECT * FROM datetime1 where id = 1").head()) } - test("SPARK-20557: column type TIMEZONE with TIME STAMP should be recognized") { + test("SPARK-20557: column type TIMESTAMP with TIME ZONE should be recognized") { val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) val rows = dfRead.collect() val types = rows(0).toSeq.map(x => x.getClass.toString) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index a1a065a443e6..eb3c458360e7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -55,6 +55,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + "null, 
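As a side note on the improved error message, a small sketch of the lookup it relies on (standard `java.sql.JDBCType`):

```scala
import java.sql.{JDBCType, Types}

// Before, an unsupported column was reported only by its numeric type code,
// e.g. "Unsupported type 2014"; resolving the code through java.sql.JDBCType
// gives the readable name instead.
val sqlType = Types.TIMESTAMP_WITH_TIMEZONE          // 2014
println(s"Unsupported type ${JDBCType.valueOf(sqlType).getName}")
// Unsupported type TIMESTAMP_WITH_TIMEZONE
```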
null, null, null, null, " + "null, null, null, null, null, null, null)" ).executeUpdate() + + conn.prepareStatement("CREATE TABLE ts_with_timezone " + + "(id integer, tstz TIMESTAMP WITH TIME ZONE, ttz TIME WITH TIME ZONE)") + .executeUpdate() + conn.prepareStatement("INSERT INTO ts_with_timezone VALUES " + + "(1, TIMESTAMP WITH TIME ZONE '2016-08-12 10:22:31.949271-07', TIME WITH TIME ZONE '17:22:31.949271+00')") + .executeUpdate() } test("Type mapping for various types") { @@ -126,4 +133,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(schema(0).dataType == FloatType) assert(schema(1).dataType == ShortType) } + + test("SPARK-20557: column type TIMESTAMP with TIME ZONE and TIME with TIME ZONE should be recognized") { + val dfRead = sqlContext.read.jdbc(jdbcUrl, "ts_with_timezone", new Properties) + val rows = dfRead.collect() + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index e170133f0f0b..fe4be963e818 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -115,7 +115,7 @@ com.h2database h2 - 1.4.183 + 1.4.195 test diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index fb877d1ca763..71eaab119d75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} import java.util.Locale import scala.collection.JavaConverters._ @@ -217,11 +217,14 @@ object JdbcUtils extends Logging { case java.sql.Types.OTHER => null case java.sql.Types.REAL => DoubleType case java.sql.Types.REF => StringType + case java.sql.Types.REF_CURSOR => null case java.sql.Types.ROWID => LongType case java.sql.Types.SMALLINT => IntegerType case java.sql.Types.SQLXML => StringType case java.sql.Types.STRUCT => StringType case java.sql.Types.TIME => TimestampType + case java.sql.Types.TIME_WITH_TIMEZONE + => TimestampType case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => TimestampType @@ -229,11 +232,14 @@ object JdbcUtils extends Logging { case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType - case _ => null + case _ => + throw new SQLException("Unrecognized SQL type " + sqlType) // scalastyle:on } - if (answer == null) throw new SQLException("Unsupported type " + sqlType) + if (answer == null) { + throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName) + } answer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index e1049c665a41..142b005850a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.types.StructType import 
org.apache.spark.storage.StorageLevel - /** * Internal implementation of the user-facing `Catalog`. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5bd36ec25ccb..d9f3689411ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.{Date, DriverManager, Timestamp} +import java.sql.{Date, DriverManager, SQLException, Timestamp} import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec @@ -141,6 +141,15 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.timezone (tz TIMESTAMP WITH TIME ZONE) " + + "AS SELECT '1999-01-08 04:05:06.543543543 GMT-08:00'") + .executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE test.array (ar ARRAY) " + + "AS SELECT '(1, 2, 3)'") + .executeUpdate() + conn.commit() conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() @@ -919,6 +928,17 @@ class JDBCSuite extends SparkFunSuite assert(res === (foobarCnt, 0L, foobarCnt) :: Nil) } + test("unsupported types") { + var e = intercept[SparkException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.TIMEZONE", new Properties()).collect() + }.getMessage + assert(e.contains("java.lang.UnsupportedOperationException: unimplemented")) + e = intercept[SQLException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY", new Properties()).collect() + }.getMessage + assert(e.contains("Unsupported type ARRAY")) + } + test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive From 63d90e7da4913917982c0501d63ccc433a9b6b46 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sat, 6 May 2017 22:28:42 -0700 Subject: [PATCH 437/512] [SPARK-18777][PYTHON][SQL] Return UDF from udf.register ## What changes were proposed in this pull request? - Move udf wrapping code from `functions.udf` to `functions.UserDefinedFunction`. - Return wrapped udf from `catalog.registerFunction` and dependent methods. - Update docstrings in `catalog.registerFunction` and `SQLContext.registerFunction`. - Unit tests. ## How was this patch tested? - Existing unit tests and docstests. - Additional tests covering new feature. Author: zero323 Closes #17831 from zero323/SPARK-18777. 
--- python/pyspark/sql/catalog.py | 11 ++++++++--- python/pyspark/sql/context.py | 12 ++++++++---- python/pyspark/sql/functions.py | 23 ++++++++++++++--------- python/pyspark/sql/tests.py | 9 +++++++++ 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 41e68a45a615..5f25dce16196 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -237,23 +237,28 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + :return: a wrapped :class:`UserDefinedFunction` - >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x)) + >>> strlen = spark.catalog.registerFunction("stringLengthString", len) >>> spark.sql("SELECT stringLengthString('test')").collect() [Row(stringLengthString(test)=u'4')] + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + >>> from pyspark.sql.types import IntegerType - >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType - >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ udf = UserDefinedFunction(f, returnType, name) self._jsparkSession.udf().registerPython(name, udf._judf) + return udf._wrapped() @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index fdb7abbad4e5..5197a9e00461 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -185,22 +185,26 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + :return: a wrapped :class:`UserDefinedFunction` - >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() [Row(stringLengthString(test)=u'4')] + >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + >>> from pyspark.sql.types import IntegerType - >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType - >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - self.sparkSession.catalog.registerFunction(name, f, returnType) + return self.sparkSession.catalog.registerFunction(name, f, returnType) @ignore_unicode_prefix @since(2.1) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 843ae3816f06..8b3487c3f108 100644 
--- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1917,6 +1917,19 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + @functools.wraps(self.func) + def wrapper(*args): + return self(*args) + + wrapper.func = self.func + wrapper.returnType = self.returnType + + return wrapper + @since(1.3) def udf(f=None, returnType=StringType()): @@ -1951,15 +1964,7 @@ def udf(f=None, returnType=StringType()): """ def _udf(f, returnType=StringType()): udf_obj = UserDefinedFunction(f, returnType) - - @functools.wraps(f) - def wrapper(*args): - return udf_obj(*args) - - wrapper.func = udf_obj.func - wrapper.returnType = udf_obj.returnType - - return wrapper + return udf_obj._wrapped() # decorator @udf, @udf() or @udf(dataType()) if f is None or isinstance(f, (str, DataType)): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f644624f7f31..7983bc536fc6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -436,6 +436,15 @@ def test_udf_with_order_by_and_limit(self): res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_udf_registration_returns_udf(self): + df = self.spark.range(10) + add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_three(id) AS plus_three").collect(), + df.select(add_three("id").alias("plus_three")).collect() + ) + def test_wholefile_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", From 37f963ac13ec1bd958c44c7c15b5e8cb6c06cbbc Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sun, 7 May 2017 10:08:06 +0100 Subject: [PATCH 438/512] [SPARK-20518][CORE] Supplement the new blockidsuite unit tests ## What changes were proposed in this pull request? This PR adds the new unit tests to support ShuffleDataBlockId , ShuffleIndexBlockId , TempShuffleBlockId , TempLocalBlockId ## How was this patch tested? The new unit test. Author: caoxuewen Closes #17794 from heary-cao/blockidsuite. 
--- .../apache/spark/storage/BlockIdSuite.scala | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index 89ed031b6fcd..f0c521b00b58 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.UUID + import org.apache.spark.SparkFunSuite class BlockIdSuite extends SparkFunSuite { @@ -67,6 +69,32 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("shuffle data") { + val id = ShuffleDataBlockId(4, 5, 6) + assertSame(id, ShuffleDataBlockId(4, 5, 6)) + assertDifferent(id, ShuffleDataBlockId(6, 5, 6)) + assert(id.name === "shuffle_4_5_6.data") + assert(id.asRDDId === None) + assert(id.shuffleId === 4) + assert(id.mapId === 5) + assert(id.reduceId === 6) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle index") { + val id = ShuffleIndexBlockId(7, 8, 9) + assertSame(id, ShuffleIndexBlockId(7, 8, 9)) + assertDifferent(id, ShuffleIndexBlockId(9, 8, 9)) + assert(id.name === "shuffle_7_8_9.index") + assert(id.asRDDId === None) + assert(id.shuffleId === 7) + assert(id.mapId === 8) + assert(id.reduceId === 9) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + test("broadcast") { val id = BroadcastBlockId(42) assertSame(id, BroadcastBlockId(42)) @@ -101,6 +129,30 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("temp local") { + val id = TempLocalBlockId(new UUID(5, 2)) + assertSame(id, TempLocalBlockId(new UUID(5, 2))) + assertDifferent(id, TempLocalBlockId(new UUID(5, 3))) + assert(id.name === "temp_local_00000000-0000-0005-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 5) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + + test("temp shuffle") { + val id = TempShuffleBlockId(new UUID(1, 2)) + assertSame(id, TempShuffleBlockId(new UUID(1, 2))) + assertDifferent(id, TempShuffleBlockId(new UUID(1, 3))) + assert(id.name === "temp_shuffle_00000000-0000-0001-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 1) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + test("test") { val id = TestBlockId("abc") assertSame(id, TestBlockId("abc")) From 88e6d75072c23fa99d4df00d087d03d8c38e8c69 Mon Sep 17 00:00:00 2001 From: Daniel Li Date: Sun, 7 May 2017 10:09:58 +0100 Subject: [PATCH 439/512] [SPARK-20484][MLLIB] Add documentation to ALS code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR adds documentation to the ALS code. ## How was this patch tested? Existing tests were used. mengxr srowen This contribution is my original work. I have the license to work on this project under the Spark project’s open source license. Author: Daniel Li Closes #17793 from danielyli/spark-20484. 
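For readers new to the class being documented, a brief, illustrative sketch of the public entry point whose internals the new Scaladoc describes (column names, parameters, and `ratingsDF` are only examples):

```scala
import org.apache.spark.ml.recommendation.ALS

// `fit` ends up in the partitioned implementation documented here
// (rating blocks, in-blocks, and out-blocks).
val als = new ALS()
  .setMaxIter(10)
  .setRank(10)
  .setRegParam(0.01)
  .setUserCol("userId")
  .setItemCol("itemId")
  .setRatingCol("rating")

// ratingsDF: a DataFrame with userId, itemId, and rating columns (placeholder).
val model = als.fit(ratingsDF)
```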
--- .../apache/spark/ml/recommendation/ALS.scala | 236 +++++++++++++++--- 1 file changed, 202 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index a20ef7244666..1562bf1beb7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -774,6 +774,28 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: * Implementation of the ALS algorithm. + * + * This implementation of the ALS factorization algorithm partitions the two sets of factors among + * Spark workers so as to reduce network communication by only sending one copy of each factor + * vector to each Spark worker on each iteration, and only if needed. This is achieved by + * precomputing some information about the ratings matrix to determine which users require which + * item factors and vice versa. See the Scaladoc for `InBlock` for a detailed explanation of how + * the precomputation is done. + * + * In addition, since each iteration of calculating the factor matrices depends on the known + * ratings, which are spread across Spark partitions, a naive implementation would incur + * significant network communication overhead between Spark workers, as the ratings RDD would be + * repeatedly shuffled during each iteration. This implementation reduces that overhead by + * performing the shuffling operation up front, precomputing each partition's ratings dependencies + * and duplicating those values to the appropriate workers before starting iterations to solve for + * the factor matrices. See the Scaladoc for `OutBlock` for a detailed explanation of how the + * precomputation is done. + * + * Note that the term "rating block" is a bit of a misnomer, as the ratings are not partitioned by + * contiguous blocks from the ratings matrix but by a hash function on the rating's location in + * the matrix. If it helps you to visualize the partitions, it is easier to think of the term + * "block" as referring to a subset of an RDD containing the ratings rather than a contiguous + * submatrix of the ratings matrix. 
*/ @DeveloperApi def train[ID: ClassTag]( // scalastyle:ignore @@ -791,32 +813,43 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { checkpointInterval: Int = 10, seed: Long = 0L)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { + require(!ratings.isEmpty(), s"No ratings available from $ratings") require(intermediateRDDStorageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") + val sc = ratings.sparkContext + + // Precompute the rating dependencies of each partition val userPart = new ALSPartitioner(numUserBlocks) val itemPart = new ALSPartitioner(numItemBlocks) - val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) - val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) - val solver = if (nonnegative) new NNLSSolver else new CholeskySolver val blockRatings = partitionRatings(ratings, userPart, itemPart) .persist(intermediateRDDStorageLevel) val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel) - // materialize blockRatings and user blocks - userOutBlocks.count() + userOutBlocks.count() // materialize blockRatings and user blocks val swappedBlockRatings = blockRatings.map { case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) => ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings)) } val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel) - // materialize item blocks - itemOutBlocks.count() + itemOutBlocks.count() // materialize item blocks + + // Encoders for storing each user/item's partition ID and index within its partition using a + // single integer; used as an optimization + val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) + val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions) + + // These are the user and item factor matrices that, once trained, are multiplied together to + // estimate the rating matrix. The two matrices are stored in RDDs, partitioned by column such + // that each factor column resides on the same Spark worker as its corresponding user or item. val seedGen = new XORShiftRandom(seed) var userFactors = initialize(userInBlocks, rank, seedGen.nextLong()) var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) + + val solver = if (nonnegative) new NNLSSolver else new CholeskySolver + var previousCheckpointFile: Option[String] = None val shouldCheckpoint: Int => Boolean = (iter) => sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0) @@ -830,6 +863,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { logWarning(s"Cannot delete checkpoint file $file:", e) } } + if (implicitPrefs) { for (iter <- 1 to maxIter) { userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel) @@ -910,26 +944,154 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { private type FactorBlock = Array[Array[Float]] /** - * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to - * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the - * src factors in this block to send to dst block 0. + * A mapping of the columns of the items factor matrix that are needed when calculating each row + * of the users factor matrix, and vice versa. 
+ * + * Specifically, when calculating a user factor vector, since only those columns of the items + * factor matrix that correspond to the items that that user has rated are needed, we can avoid + * having to repeatedly copy the entire items factor matrix to each worker later in the algorithm + * by precomputing these dependencies for all users, storing them in an RDD of `OutBlock`s. The + * items' dependencies on the columns of the users factor matrix is computed similarly. + * + * =Example= + * + * Using the example provided in the `InBlock` Scaladoc, `userOutBlocks` would look like the + * following: + * + * {{{ + * userOutBlocks.collect() == Seq( + * 0 -> Array(Array(0, 1), Array(0, 1)), + * 1 -> Array(Array(0), Array(0)) + * ) + * }}} + * + * Each value in this map-like sequence is of type `Array[Array[Int]]`. The values in the + * inner array are the ranks of the sorted user IDs in that partition; so in the example above, + * `Array(0, 1)` in partition 0 refers to user IDs 0 and 6, since when all unique user IDs in + * partition 0 are sorted, 0 is the first ID and 6 is the second. The position of each inner + * array in its enclosing outer array denotes the partition number to which item IDs map; in the + * example, the first `Array(0, 1)` is in position 0 of its outer array, denoting item IDs that + * map to partition 0. + * + * In summary, the data structure encodes the following information: + * + * * There are ratings with user IDs 0 and 6 (encoded in `Array(0, 1)`, where 0 and 1 are the + * indices of the user IDs 0 and 6 on partition 0) whose item IDs map to partitions 0 and 1 + * (represented by the fact that `Array(0, 1)` appears in both the 0th and 1st positions). + * + * * There are ratings with user ID 3 (encoded in `Array(0)`, where 0 is the index of the user + * ID 3 on partition 1) whose item IDs map to partitions 0 and 1 (represented by the fact that + * `Array(0)` appears in both the 0th and 1st positions). */ private type OutBlock = Array[Array[Int]] /** - * In-link block for computing src (user/item) factors. This includes the original src IDs - * of the elements within this block as well as encoded dst (item/user) indices and corresponding - * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original - * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices. - * For example, if we have an in-link record + * In-link block for computing user and item factor matrices. + * + * The ALS algorithm partitions the columns of the users factor matrix evenly among Spark workers. + * Since each column of the factor matrix is calculated using the known ratings of the correspond- + * ing user, and since the ratings don't change across iterations, the ALS algorithm preshuffles + * the ratings to the appropriate partitions, storing them in `InBlock` objects. + * + * The ratings shuffled by item ID are computed similarly and also stored in `InBlock` objects. + * Note that this means every rating is stored twice, once as shuffled by user ID and once by item + * ID. This is a necessary tradeoff, since in general a rating will not be on the same worker + * when partitioned by user as by item. + * + * =Example= + * + * Say we have a small collection of eight items to offer the seven users in our application. 
We + * have some known ratings given by the users, as seen in the matrix below: + * + * {{{ + * Items + * 0 1 2 3 4 5 6 7 + * +---+---+---+---+---+---+---+---+ + * 0 | |0.1| | |0.4| | |0.7| + * +---+---+---+---+---+---+---+---+ + * 1 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * U 2 | | | | | | | | | + * s +---+---+---+---+---+---+---+---+ + * e 3 | |3.1| | |3.4| | |3.7| + * r +---+---+---+---+---+---+---+---+ + * s 4 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * 5 | | | | | | | | | + * +---+---+---+---+---+---+---+---+ + * 6 | |6.1| | |6.4| | |6.7| + * +---+---+---+---+---+---+---+---+ + * }}} + * + * The ratings are represented as an RDD, passed to the `partitionRatings` method as the `ratings` + * parameter: + * + * {{{ + * ratings.collect() == Seq( + * Rating(0, 1, 0.1f), + * Rating(0, 4, 0.4f), + * Rating(0, 7, 0.7f), + * Rating(3, 1, 3.1f), + * Rating(3, 4, 3.4f), + * Rating(3, 7, 3.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 4, 6.4f), + * Rating(6, 7, 6.7f) + * ) + * }}} * - * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0}, + * Say that we are using two partitions to calculate each factor matrix: * - * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which - * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3). + * {{{ + * val userPart = new ALSPartitioner(2) + * val itemPart = new ALSPartitioner(2) + * val blockRatings = partitionRatings(ratings, userPart, itemPart) + * }}} * - * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can - * compute src factors one after another using only one normal equation instance. + * Ratings are mapped to partitions using the user/item IDs modulo the number of partitions. With + * two partitions, ratings with even-valued user IDs are shuffled to partition 0 while those with + * odd-valued user IDs are shuffled to partition 1: + * + * {{{ + * userInBlocks.collect() == Seq( + * 0 -> Seq( + * // Internally, the class stores the ratings in a more optimized format than + * // a sequence of `Rating`s, but for clarity we show it as such here. + * Rating(0, 1, 0.1f), + * Rating(0, 4, 0.4f), + * Rating(0, 7, 0.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 4, 6.4f), + * Rating(6, 7, 6.7f) + * ), + * 1 -> Seq( + * Rating(3, 1, 3.1f), + * Rating(3, 4, 3.4f), + * Rating(3, 7, 3.7f) + * ) + * ) + * }}} + * + * Similarly, ratings with even-valued item IDs are shuffled to partition 0 while those with + * odd-valued item IDs are shuffled to partition 1: + * + * {{{ + * itemInBlocks.collect() == Seq( + * 0 -> Seq( + * Rating(0, 4, 0.4f), + * Rating(3, 4, 3.4f), + * Rating(6, 4, 6.4f) + * ), + * 1 -> Seq( + * Rating(0, 1, 0.1f), + * Rating(0, 7, 0.7f), + * Rating(3, 1, 3.1f), + * Rating(3, 7, 3.7f), + * Rating(6, 1, 6.1f), + * Rating(6, 7, 6.7f) + * ) + * ) + * }}} * * @param srcIds src ids (ordered) * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and @@ -1026,7 +1188,24 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { } /** - * Partitions raw ratings into blocks. + * Groups an RDD of [[Rating]]s by the user partition and item partition to which each `Rating` + * maps according to the given partitioners. The returned pair RDD holds the ratings, encoded in + * a memory-efficient format but otherwise unchanged, keyed by the (user partition ID, item + * partition ID) pair. 
+ * + * Performance note: This is an expensive operation that performs an RDD shuffle. + * + * Implementation note: This implementation produces the same result as the following but + * generates fewer intermediate objects: + * + * {{{ + * ratings.map { r => + * ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) + * }.aggregateByKey(new RatingBlockBuilder)( + * seqOp = (b, r) => b.add(r), + * combOp = (b0, b1) => b0.merge(b1.build())) + * .mapValues(_.build()) + * }}} * * @param ratings raw ratings * @param srcPart partitioner for src IDs @@ -1037,17 +1216,6 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { ratings: RDD[Rating[ID]], srcPart: Partitioner, dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = { - - /* The implementation produces the same result as the following but generates less objects. - - ratings.map { r => - ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r) - }.aggregateByKey(new RatingBlockBuilder)( - seqOp = (b, r) => b.add(r), - combOp = (b0, b1) => b0.merge(b1.build())) - .mapValues(_.build()) - */ - val numPartitions = srcPart.numPartitions * dstPart.numPartitions ratings.mapPartitions { iter => val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID]) @@ -1135,8 +1303,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { def length: Int = srcIds.length /** - * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a - * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. + * Compresses the block into an `InBlock`. The algorithm is the same as converting a sparse + * matrix from coordinate list (COO) format into compressed sparse column (CSC) format. * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. */ def compress(): InBlock[ID] = { From 2cf83c47838115f71419ba5b9296c69ec1d746cd Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Sun, 7 May 2017 10:15:31 +0100 Subject: [PATCH 440/512] [SPARK-7481][BUILD] Add spark-hadoop-cloud module to pull in object store access. ## What changes were proposed in this pull request? Add a new `spark-hadoop-cloud` module and maven profile to pull in object store support from `hadoop-openstack`, `hadoop-aws` and `hadoop-azure` (Hadoop 2.7+) JARs, along with their dependencies, fixing up the dependencies so that everything works, in particular Jackson. It restores `s3n://` access to S3, adds its `s3a://` replacement, OpenStack `swift://` and azure `wasb://`. There's a documentation page, `cloud_integration.md`, which covers the basic details of using Spark with object stores, referring the reader to the supplier's own documentation, with specific warnings on security and the possible mismatch between a store's behavior and that of a filesystem. In particular, users are advised be very cautious when trying to use an object store as the destination of data, and to consult the documentation of the storage supplier and the connector. (this is the successor to #12004; I can't re-open it) ## How was this patch tested? Downstream tests exist in [https://github.com/steveloughran/spark-cloud-examples/tree/master/cloud-examples](https://github.com/steveloughran/spark-cloud-examples/tree/master/cloud-examples) Those verify that the dependencies are sufficient to allow downstream applications to work with s3a, azure wasb and swift storage connectors, and perform basic IO & dataframe operations thereon. All seems well. 
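(Editor's aside, not part of the patch: a hedged sketch of the kind of access the new module enables once the connector JARs and valid credentials are on the classpath; the bucket, container, and account names below are made up.)

```scala
// Assumes an existing SparkSession named `spark` and credentials for each store.
val logs = spark.read.text("s3a://example-bucket/logs/")                      // S3 via s3a
logs.write.parquet("wasb://data@exampleaccount.blob.core.windows.net/out/")   // Azure via wasb
```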
Manually clean build & verify that assembly contains the relevant aws-* hadoop-* artifacts on Hadoop 2.6; azure on a hadoop-2.7 profile. SBT build: `build/sbt -Phadoop-cloud -Phadoop-2.7 package` maven build `mvn install -Phadoop-cloud -Phadoop-2.7` This PR *does not* update `dev/deps/spark-deps-hadoop-2.7` or `dev/deps/spark-deps-hadoop-2.6`, because unless the hadoop-cloud profile is enabled, no extra JARs show up in the dependency list. The dependency check in Jenkins isn't setting the property, so the new JARs aren't visible. Author: Steve Loughran Author: Steve Loughran Closes #17834 from steveloughran/cloud/SPARK-7481-current. --- assembly/pom.xml | 14 +++ docs/cloud-integration.md | 200 ++++++++++++++++++++++++++++++++ docs/index.md | 1 + docs/rdd-programming-guide.md | 6 +- docs/storage-openstack-swift.md | 38 ++---- hadoop-cloud/pom.xml | 185 +++++++++++++++++++++++++++++ pom.xml | 7 ++ project/SparkBuild.scala | 4 +- 8 files changed, 424 insertions(+), 31 deletions(-) create mode 100644 docs/cloud-integration.md create mode 100644 hadoop-cloud/pom.xml diff --git a/assembly/pom.xml b/assembly/pom.xml index 742a4a1531e7..464af16e46f6 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -226,5 +226,19 @@ provided + + + + hadoop-cloud + + + org.apache.spark + spark-hadoop-cloud_${scala.binary.version} + ${project.version} + + + diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md new file mode 100644 index 000000000000..751a192da4ff --- /dev/null +++ b/docs/cloud-integration.md @@ -0,0 +1,200 @@ +--- +layout: global +displayTitle: Integration with Cloud Infrastructures +title: Integration with Cloud Infrastructures +description: Introduction to cloud storage support in Apache Spark SPARK_VERSION_SHORT +--- + + +* This will become a table of contents (this text will be scraped). +{:toc} + +## Introduction + + +All major cloud providers offer persistent data storage in *object stores*. +These are not classic "POSIX" file systems. +In order to store hundreds of petabytes of data without any single points of failure, +object stores replace the classic filesystem directory tree +with a simpler model of `object-name => data`. To enable remote access, operations +on objects are usually offered as (slow) HTTP REST operations. + +Spark can read and write data in object stores through filesystem connectors implemented +in Hadoop or provided by the infrastructure suppliers themselves. +These connectors make the object stores look *almost* like filesystems, with directories and files +and the classic operations on them such as list, delete and rename. + + +### Important: Cloud Object Stores are Not Real Filesystems + +While the stores appear to be filesystems, underneath +they are still object stores, [and the difference is significant](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/filesystem/introduction.html) + +They cannot be used as a direct replacement for a cluster filesystem such as HDFS +*except where this is explicitly stated*. + +Key differences are: + +* Changes to stored objects may not be immediately visible, both in directory listings and actual data access. +* The means by which directories are emulated may make working with them slow. +* Rename operations may be very slow and, on failure, leave the store in an unknown state. +* Seeking within a file may require new HTTP calls, hurting performance. + +How does this affect Spark? + +1. Reading and writing data can be significantly slower than working with a normal filesystem. +1. 
Some directory structures may be very inefficient to scan during query split calculation. +1. The output of work may not be immediately visible to a follow-on query. +1. The rename-based algorithm by which Spark normally commits work when saving an RDD, DataFrame or Dataset + is potentially both slow and unreliable. + +For these reasons, it is not always safe to use an object store as a direct destination of queries, or as +an intermediate store in a chain of queries. Consult the documentation of the object store and its +connector to determine which uses are considered safe. + +In particular: *without some form of consistency layer, Amazon S3 cannot +be safely used as the direct destination of work with the normal rename-based committer.* + +### Installation + +With the relevant libraries on the classpath and Spark configured with valid credentials, +objects can be read or written by using their URLs as the path to data. +For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create +an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. + +To add the relevant libraries to an application's classpath, include the `hadoop-cloud` +module and its dependencies. + +In Maven, add the following to the `pom.xml` file, assuming `spark.version` +is set to the chosen version of Spark: + +{% highlight xml %} + + ... + + org.apache.spark + hadoop-cloud_2.11 + ${spark.version} + + ... + +{% endhighlight %} + +Commercial products based on Apache Spark generally directly set up the classpath +for talking to cloud infrastructures, in which case this module may not be needed. + +### Authenticating + +Spark jobs must authenticate with the object stores to access data within them. + +1. When Spark is running in a cloud infrastructure, the credentials are usually automatically set up. +1. `spark-submit` reads the `AWS_ACCESS_KEY`, `AWS_SECRET_KEY` +and `AWS_SESSION_TOKEN` environment variables and sets the associated authentication options +for the `s3n` and `s3a` connectors to Amazon S3. +1. In a Hadoop cluster, settings may be set in the `core-site.xml` file. +1. Authentication details may be manually added to the Spark configuration in `spark-defaults.conf` +1. Alternatively, they can be programmatically set in the `SparkConf` instance used to configure +the application's `SparkContext`. + +*Important: never check authentication secrets into source code repositories, +especially public ones* + +Consult [the Hadoop documentation](https://hadoop.apache.org/docs/current/) for the relevant +configuration and security options. + +## Configuring + +Each cloud connector has its own set of configuration parameters, again, +consult the relevant documentation. + +### Recommended settings for writing to object stores + +For object stores whose consistency model means that rename-based commits are safe, +use the `FileOutputCommitter` v2 algorithm for performance: + +``` +spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version 2 +``` + +This does less renaming at the end of a job than the "version 1" algorithm. +As it still uses `rename()` to commit files, it is unsafe to use +when the object store does not have consistent metadata/listings.
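(Editor's aside: the same committer setting can be applied programmatically when building a `SparkSession`, which is equivalent to placing the key in `spark-defaults.conf`; this is a minimal sketch, not part of the documentation being added.)

```scala
import org.apache.spark.sql.SparkSession

// Only use the v2 committer when the target store's listings are consistent; see the warning above.
val spark = SparkSession.builder()
  .appName("object-store-output")
  .config("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version", "2")
  .getOrCreate()
```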
+ +The committer can also be set to ignore failures when cleaning up temporary +files; this reduces the risk that a transient network problem is escalated into a +job failure: + +``` +spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored true +``` + +Storing temporary files can run up charges; delete +directories called `"_temporary"` on a regular basis to avoid this. + +### Parquet I/O Settings + +For optimal performance when working with Parquet data, use the following settings: + +``` +spark.hadoop.parquet.enable.summary-metadata false +spark.sql.parquet.mergeSchema false +spark.sql.parquet.filterPushdown true +spark.sql.hive.metastorePartitionPruning true +``` + +These minimise the amount of data read during queries. + +### ORC I/O Settings + +For best performance when working with ORC data, use these settings: + +``` +spark.sql.orc.filterPushdown true +spark.sql.orc.splits.include.file.footer true +spark.sql.orc.cache.stripe.details.size 10000 +spark.sql.hive.metastorePartitionPruning true +``` + +Again, these minimise the amount of data read during queries. + +## Spark Streaming and Object Storage + +Spark Streaming can monitor files added to object stores by +creating a `FileInputDStream` to monitor a path in the store through a call to +`StreamingContext.textFileStream()`. + +1. The time to scan for new files is proportional to the number of files +under the path, not the number of *new* files, so it can become a slow operation. +The size of the window needs to be set to handle this. + +1. Files only appear in an object store once they are completely written; there +is no need for a workflow of write-then-rename to ensure that files aren't picked up +while they are still being written. Applications can write straight to the monitored directory. + +1. Streams should only be checkpointed to a store implementing a fast and +atomic `rename()` operation. Otherwise, the checkpointing may be slow and potentially unreliable. + +## Further Reading + +Here is the documentation on the standard connectors, both from Apache and the cloud providers. + +* [OpenStack Swift](https://hadoop.apache.org/docs/current/hadoop-openstack/index.html). Hadoop 2.6+ +* [Azure Blob Storage](https://hadoop.apache.org/docs/current/hadoop-azure/index.html). Since Hadoop 2.7 +* [Azure Data Lake](https://hadoop.apache.org/docs/current/hadoop-azure-datalake/index.html). Since Hadoop 2.8 +* [Amazon S3 via S3A and S3N](https://hadoop.apache.org/docs/current/hadoop-aws/tools/hadoop-aws/index.html). Hadoop 2.6+ +* [Amazon EMR File System (EMRFS)](https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-fs.html). From Amazon +* [Google Cloud Storage Connector for Spark and Hadoop](https://cloud.google.com/hadoop/google-cloud-storage-connector).
From Google + + diff --git a/docs/index.md b/docs/index.md index ad4f24ff1a5d..960b968454d0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -126,6 +126,7 @@ options for deployment: * [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware * Integration with other storage systems: + * [Cloud Infrastructures](cloud-integration.html) * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system * [Contributing to Spark](http://spark.apache.org/contributing.html) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index e2bf2d7ca77c..52e59df9990e 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -323,7 +323,7 @@ One important parameter for parallel collections is the number of *partitions* t Spark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight scala %} scala> val distFile = sc.textFile("data.txt") @@ -356,7 +356,7 @@ Apart from text files, Spark's Scala API also supports several other data format Spark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight java %} JavaRDD distFile = sc.textFile("data.txt"); @@ -388,7 +388,7 @@ Apart from text files, Spark's Java API also supports several other data formats PySpark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. 
Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html). -Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation: +Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3a://`, etc URI) and reads it as a collection of lines. Here is an example invocation: {% highlight python %} >>> distFile = sc.textFile("data.txt") diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index c39ef1ce59e1..f4bb2353e3c4 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -8,7 +8,8 @@ same URI formats as in Hadoop. You can specify a path in Swift as input through URI of the form swift://container.PROVIDER/path. You will also need to set your Swift security credentials, through core-site.xml or via SparkContext.hadoopConfiguration. -Current Swift driver requires Swift to use Keystone authentication method. +The current Swift driver requires Swift to use the Keystone authentication method, or +its Rackspace-specific predecessor. # Configuring Swift for Better Data Locality @@ -19,41 +20,30 @@ Although not mandatory, it is recommended to configure the proxy server of Swift # Dependencies -The Spark application should include hadoop-openstack dependency. +The Spark application should include hadoop-openstack dependency, which can +be done by including the `hadoop-cloud` module for the specific version of spark used. For example, for Maven support, add the following to the pom.xml file: {% highlight xml %} ... - org.apache.hadoop - hadoop-openstack - 2.3.0 + org.apache.spark + hadoop-cloud_2.11 + ${spark.version} ... {% endhighlight %} - # Configuration Parameters Create core-site.xml and place it inside Spark's conf directory. -There are two main categories of parameters that should to be configured: declaration of the -Swift driver and the parameters that are required by Keystone. +The main category of parameters that should be configured are the authentication parameters +required by Keystone. -Configuration of Hadoop to use Swift File system achieved via - - - - - - - -
<th>Property Name</th><th>Value</th>
<td>fs.swift.impl</td><td>org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem</td>
    - -Additional parameters required by Keystone (v2.0) and should be provided to the Swift driver. Those -parameters will be used to perform authentication in Keystone to access Swift. The following table -contains a list of Keystone mandatory parameters. PROVIDER can be any name. +The following table contains a list of Keystone mandatory parameters. PROVIDER can be +any (alphanumeric) name. @@ -94,7 +84,7 @@ contains a list of Keystone mandatory parameters. PROVIDER can be a - +
<th>Property Name</th><th>Meaning</th><th>Required</th>
<td>fs.swift.service.PROVIDER.public</td> -<td>Indicates if all URLs are public</td> +<td>Indicates whether to use the public (off cloud) or private (in cloud; no transfer fees) endpoints</td> <td>Mandatory</td>
    @@ -104,10 +94,6 @@ defined for tenant test. Then core-site.xml should inc {% highlight xml %} - - fs.swift.impl - org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem - fs.swift.service.SparkTest.auth.url http://127.0.0.1:5000/v2.0/tokens diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml new file mode 100644 index 000000000000..aa36dd4774d8 --- /dev/null +++ b/hadoop-cloud/pom.xml @@ -0,0 +1,185 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../pom.xml + + + spark-hadoop-cloud_2.11 + jar + Spark Project Cloud Integration through Hadoop Libraries + + Contains support for cloud infrastructures, specifically the Hadoop JARs and + transitive dependencies needed to interact with the infrastructures, + making everything consistent with Spark's other dependencies. + + + hadoop-cloud + + + + + + org.apache.hadoop + hadoop-aws + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + commons-logging + commons-logging + + + org.codehaus.jackson + jackson-mapper-asl + + + org.codehaus.jackson + jackson-core-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + + + org.apache.hadoop + hadoop-openstack + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + commons-logging + commons-logging + + + junit + junit + + + org.mockito + mockito-all + + + + + + + joda-time + joda-time + ${hadoop.deps.scope} + + + + com.fasterxml.jackson.core + jackson-databind + ${hadoop.deps.scope} + + + com.fasterxml.jackson.core + jackson-annotations + ${hadoop.deps.scope} + + + com.fasterxml.jackson.dataformat + jackson-dataformat-cbor + ${fasterxml.jackson.version} + + + + org.apache.httpcomponents + httpclient + ${hadoop.deps.scope} + + + + org.apache.httpcomponents + httpcore + ${hadoop.deps.scope} + + + + + + + hadoop-2.7 + + + + + + org.apache.hadoop + hadoop-azure + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + + + + diff --git a/pom.xml b/pom.xml index a1a1817e2f7d..0533a8dcf2e0 100644 --- a/pom.xml +++ b/pom.xml @@ -2546,6 +2546,13 @@ + + hadoop-cloud + + hadoop-cloud + + + scala-2.10 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e52baf51aed1..b5362ec1ae45 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -57,9 +57,9 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(mesos, yarn, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests) = + streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = Seq("mesos", "yarn", "ganglia-lgpl", "streaming-kinesis-asl", - "docker-integration-tests").map(ProjectRef(buildLocation, _)) + "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") From 7087e01194964a1aad0b45bdb41506a17100eacf Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 7 May 2017 13:10:10 -0700 Subject: [PATCH 441/512] [SPARK-20543][SPARKR][FOLLOWUP] Don't skip tests on AppVeyor 
## What changes were proposed in this pull request? add environment ## How was this patch tested? wait for appveyor run Author: Felix Cheung Closes #17878 from felixcheung/appveyorrcran. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- appveyor.yml | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 47cc34a6c5b7..232246d6be9b 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -3387,7 +3387,7 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { - skip_on_cran() + skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. diff --git a/appveyor.yml b/appveyor.yml index 4d31af70f056..58c2e98289e9 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -48,6 +48,9 @@ install: build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package +environment: + NOT_CRAN: true + test_script: - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R @@ -56,4 +59,3 @@ notifications: on_build_success: false on_build_failure: false on_build_status_changed: false - From 500436b4368207db9e9b9cef83f9c11d33e31e1a Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Sun, 7 May 2017 13:56:13 -0700 Subject: [PATCH 442/512] [MINOR][SQL][DOCS] Improve unix_timestamp's scaladoc (and typo hunting) ## What changes were proposed in this pull request? * Docs are consistent (across different `unix_timestamp` variants and their internal expressions) * typo hunting ## How was this patch tested? local build Author: Jacek Laskowski Closes #17801 from jaceklaskowski/unix_timestamp. --- .../expressions/datetimeExpressions.scala | 6 ++--- .../sql/catalyst/util/DateTimeUtils.scala | 2 +- .../org/apache/spark/sql/functions.scala | 26 ++++++++++++------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index bb8fd5032d63..a98cd33f2780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -488,7 +488,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti * Deterministic version of [[UnixTimestamp]], must have at least one parameter. 
*/ @ExpressionDescription( - usage = "_FUNC_(expr[, pattern]) - Returns the UNIX timestamp of the give time.", + usage = "_FUNC_(expr[, pattern]) - Returns the UNIX timestamp of the given time.", extended = """ Examples: > SELECT _FUNC_('2016-04-08', 'yyyy-MM-dd'); @@ -1225,8 +1225,8 @@ case class ParseToTimestamp(left: Expression, format: Expression, child: Express extends RuntimeReplaceable { def this(left: Expression, format: Expression) = { - this(left, format, Cast(UnixTimestamp(left, format), TimestampType)) -} + this(left, format, Cast(UnixTimestamp(left, format), TimestampType)) + } override def flatArguments: Iterator[Any] = Iterator(left, format) override def sql: String = s"$prettyName(${left.sql}, ${format.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index eb6aad5b2d2b..6c1592fd8881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -423,7 +423,7 @@ object DateTimeUtils { } /** - * Parses a given UTF8 date string to the corresponding a corresponding [[Int]] value. + * Parses a given UTF8 date string to a corresponding [[Int]] value. * The return type is [[Option]] in order to distinguish between 0 and null. The following * formats are allowed: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f07e04368389..987011edfe1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2491,10 +2491,10 @@ object functions { * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. * - * A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All - * pattern letters of `java.text.SimpleDateFormat` can be used. + * A pattern `dd.MM.yyyy` would return a string like `18.03.1993`. + * All pattern letters of `java.text.SimpleDateFormat` can be used. * - * @note Use when ever possible specialized functions like [[year]]. These benefit from a + * @note Use specialized functions like [[year]] whenever possible as they benefit from a * specialized implementation. * * @group datetime_funcs @@ -2647,7 +2647,11 @@ object functions { } /** - * Gets current Unix timestamp in seconds. + * Returns the current Unix timestamp (in seconds). + * + * @note All calls of `unix_timestamp` within the same query return the same value + * (i.e. the current timestamp is calculated at the start of query evaluation). + * * @group datetime_funcs * @since 1.5.0 */ @@ -2657,7 +2661,9 @@ object functions { /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), - * using the default timezone and the default locale, return null if fail. + * using the default timezone and the default locale. + * Returns `null` if fails. + * * @group datetime_funcs * @since 1.5.0 */ @@ -2666,13 +2672,15 @@ object functions { } /** - * Convert time string with given pattern - * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) - * to Unix time stamp (in seconds), return null if fail. + * Converts time string with given pattern to Unix timestamp (in seconds). + * Returns `null` if fails. 
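(Editor's aside, not part of the patch: a brief usage sketch of the variants documented above, assuming a DataFrame `df` with a string column `date_str`.)

```scala
import org.apache.spark.sql.functions.{col, unix_timestamp}

// Pattern-based variant: returns null for values that do not parse.
val withTs = df.select(unix_timestamp(col("date_str"), "yyyy-MM-dd").as("ts"))

// No-argument variant: the current timestamp, fixed at the start of query evaluation.
val now = df.select(unix_timestamp())
```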
+ * + * @see + * Customizing Formats * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = withExpr {UnixTimestamp(s.expr, Literal(p)) } + def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) } /** * Convert time string to a Unix timestamp (in seconds). From 1f73d3589a84b78473598c17ac328a9805896778 Mon Sep 17 00:00:00 2001 From: zero323 Date: Sun, 7 May 2017 16:24:42 -0700 Subject: [PATCH 443/512] [SPARK-20550][SPARKR] R wrapper for Dataset.alias ## What changes were proposed in this pull request? - Add SparkR wrapper for `Dataset.alias`. - Adjust roxygen annotations for `functions.alias` (including example usage). ## How was this patch tested? Unit tests, `check_cran.sh`. Author: zero323 Closes #17825 from zero323/SPARK-20550. --- R/pkg/R/DataFrame.R | 24 +++++++++++++++++++++++ R/pkg/R/column.R | 16 +++++++-------- R/pkg/R/generics.R | 11 +++++++++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 10 ++++++++++ 4 files changed, 53 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1c8869202f67..b56dddcb9f2e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3745,3 +3745,27 @@ setMethod("hint", jdf <- callJMethod(x@sdf, "hint", name, parameters) dataFrame(jdf) }) + +#' alias +#' +#' @aliases alias,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname alias +#' @name alias +#' @export +#' @examples +#' \dontrun{ +#' df <- alias(createDataFrame(mtcars), "mtcars") +#' avg_mpg <- alias(agg(groupBy(df, df$cyl), avg(df$mpg)), "avg_mpg") +#' +#' head(select(df, column("mtcars.mpg"))) +#' head(join(df, avg_mpg, column("mtcars.cyl") == column("avg_mpg.cyl"))) +#' } +#' @note alias(SparkDataFrame) since 2.3.0 +setMethod("alias", + signature(object = "SparkDataFrame"), + function(object, data) { + stopifnot(is.character(data)) + sdf <- callJMethod(object@sdf, "alias", data) + dataFrame(sdf) + }) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 147ee4b6887b..574078012ada 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -130,19 +130,19 @@ createMethods <- function() { createMethods() -#' alias -#' -#' Set a new name for a column -#' -#' @param object Column to rename -#' @param data new name to use -#' #' @rdname alias #' @name alias #' @aliases alias,Column-method #' @family colum_func #' @export -#' @note alias since 1.4.0 +#' @examples \dontrun{ +#' df <- createDataFrame(iris) +#' +#' head(select( +#' df, alias(df$Sepal_Length, "slength"), alias(df$Petal_Length, "plength") +#' )) +#' } +#' @note alias(Column) since 1.4.0 setMethod("alias", signature(object = "Column"), function(object, data) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e835ef3e4f40..3c84bf8a4803 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -387,6 +387,17 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @export setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +#' alias +#' +#' Returns a new SparkDataFrame or a Column with an alias set. Equivalent to SQL "AS" keyword. +#' +#' @name alias +#' @rdname alias +#' @param object x a SparkDataFrame or a Column +#' @param data new name to use +#' @return a SparkDataFrame or a Column +NULL + #' @rdname arrange #' @export setGeneric("arrange", function(x, col, ...) 
{ standardGeneric("arrange") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 232246d6be9b..0856bab5686c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1223,6 +1223,16 @@ test_that("select with column", { expect_equal(columns(df4), c("name", "age")) expect_equal(count(df4), 3) + # Test select with alias + df5 <- alias(df, "table") + + expect_equal(columns(select(df5, column("table.name"))), "name") + expect_equal(columns(select(df5, "table.name")), "name") + + # Test that stats::alias is not masked + expect_is(alias(aov(yield ~ block + N * P * K, npk)), "listof") + + expect_error(select(df, c("name", "age"), "name"), "To select multiple columns, use a character vector or list for col") }) From f53a820721fe0525c275e2bb4415c20909c42dc3 Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 8 May 2017 10:58:27 +0800 Subject: [PATCH 444/512] [SPARK-16931][PYTHON][SQL] Add Python wrapper for bucketBy ## What changes were proposed in this pull request? Adds Python wrappers for `DataFrameWriter.bucketBy` and `DataFrameWriter.sortBy` ([SPARK-16931](https://issues.apache.org/jira/browse/SPARK-16931)) ## How was this patch tested? Unit tests covering new feature. __Note__: Based on work of GregBowyer (f49b9a23468f7af32cb53d2b654272757c151725) CC HyukjinKwon Author: zero323 Author: Greg Bowyer Closes #17077 from zero323/SPARK-16931. --- python/pyspark/sql/readwriter.py | 57 ++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 54 ++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 960fb882cf90..90ce8f81eb7f 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -563,6 +563,63 @@ def partitionBy(self, *cols): self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) return self + @since(2.3) + def bucketBy(self, numBuckets, col, *cols): + """Buckets the output by the given columns.If specified, + the output is laid out on the file system similar to Hive's bucketing scheme. + + :param numBuckets: the number of buckets to save + :param col: a name of a column, or a list of names. + :param cols: additional names (optional). If `col` is a list it should be empty. + + .. note:: Applicable for file-based data sources in combination with + :py:meth:`DataFrameWriter.saveAsTable`. + + >>> (df.write.format('parquet') + ... .bucketBy(100, 'year', 'month') + ... .mode("overwrite") + ... .saveAsTable('bucketed_table')) + """ + if not isinstance(numBuckets, int): + raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets))) + + if isinstance(col, (list, tuple)): + if cols: + raise ValueError("col is a {0} but cols are not empty".format(type(col))) + + col, cols = col[0], col[1:] + + if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)): + raise TypeError("all names should be `str`") + + self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols)) + return self + + @since(2.3) + def sortBy(self, col, *cols): + """Sorts the output in each bucket by the given columns on the file system. + + :param col: a name of a column, or a list of names. + :param cols: additional names (optional). If `col` is a list it should be empty. + + >>> (df.write.format('parquet') + ... .bucketBy(100, 'year', 'month') + ... .sortBy('day') + ... .mode("overwrite") + ... 
.saveAsTable('sorted_bucketed_table')) + """ + if isinstance(col, (list, tuple)): + if cols: + raise ValueError("col is a {0} but cols are not empty".format(type(col))) + + col, cols = col[0], col[1:] + + if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)): + raise TypeError("all names should be `str`") + + self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols)) + return self + @since(1.4) def save(self, path=None, format=None, mode=None, partitionBy=None, **options): """Saves the contents of the :class:`DataFrame` to a data source. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7983bc536fc6..e3fe01eae243 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -211,6 +211,12 @@ def test_sqlcontext_reuses_sparksession(self): sqlContext2 = SQLContext(self.sc) self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession) + def tearDown(self): + super(SQLTests, self).tearDown() + + # tear down test_bucketed_write state + self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket") + def test_row_should_be_read_only(self): row = Row(a=1, b=2) self.assertEqual(1, row.a) @@ -2196,6 +2202,54 @@ def test_BinaryType_serialization(self): df = self.spark.createDataFrame(data, schema=schema) df.collect() + def test_bucketed_write(self): + data = [ + (1, "foo", 3.0), (2, "foo", 5.0), + (3, "bar", -1.0), (4, "bar", 6.0), + ] + df = self.spark.createDataFrame(data, ["x", "y", "z"]) + + def count_bucketed_cols(names, table="pyspark_bucket"): + """Given a sequence of column names and a table name + query the catalog and return number o columns which are + used for bucketing + """ + cols = self.spark.catalog.listColumns(table) + num = len([c for c in cols if c.name in names and c.isBucket]) + return num + + # Test write with one bucketing column + df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x"]), 1) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write two bucketing columns + df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x", "y"]), 2) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort + df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x"]), 1) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with a list of columns + df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket") + self.assertEqual(count_bucketed_cols(["x", "y"]), 2) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort with a list of columns + (df.write.bucketBy(2, "x") + .sortBy(["y", "z"]) + .mode("overwrite").saveAsTable("pyspark_bucket")) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + + # Test write with bucket and sort with multiple columns + (df.write.bucketBy(2, "x") + .sortBy("y", "z") + .mode("overwrite").saveAsTable("pyspark_bucket")) + self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect())) + class HiveSparkSubmitTests(SparkSubmitTests): From 22691556e5f0dfbac81b8cc9ca0a67c70c1711ca Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 8 May 2017 12:16:00 +0900 Subject: [PATCH 445/512] 
[SPARK-12297][SQL] Hive compatibility for Parquet Timestamps ## What changes were proposed in this pull request? This change allows timestamps in parquet-based hive table to behave as a "floating time", without a timezone, as timestamps are for other file formats. If the storage timezone is the same as the session timezone, this conversion is a no-op. When data is read from a hive table, the table property is *always* respected. This allows spark to not change behavior when reading old data, but read newly written data correctly (whatever the source of the data is). Spark inherited the original behavior from Hive, but Hive is also updating behavior to use the same scheme in HIVE-12767 / HIVE-16231. The default for Spark remains unchanged; created tables do not include the new table property. This will only apply to hive tables; nothing is added to parquet metadata to indicate the timezone, so data that is read or written directly from parquet files will never have any conversions applied. ## How was this patch tested? Added a unit test which creates tables, reads and writes data, under a variety of permutations (different storage timezones, different session timezones, vectorized reading on and off). Author: Imran Rashid Closes #16781 from squito/SPARK-12297. --- .../sql/catalyst/catalog/interface.scala | 4 +- .../sql/catalyst/util/DateTimeUtils.scala | 5 + .../parquet/VectorizedColumnReader.java | 28 +- .../VectorizedParquetRecordReader.java | 6 +- .../spark/sql/execution/command/tables.scala | 8 +- .../parquet/ParquetFileFormat.scala | 2 + .../parquet/ParquetReadSupport.scala | 3 +- .../parquet/ParquetRecordMaterializer.scala | 9 +- .../parquet/ParquetRowConverter.scala | 53 ++- .../parquet/ParquetWriteSupport.scala | 25 +- .../spark/sql/hive/HiveExternalCatalog.scala | 11 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 12 +- .../hive/ParquetHiveCompatibilitySuite.scala | 379 +++++++++++++++++- 13 files changed, 516 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index cc0cbba275b8..c39017ebbfe6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -132,10 +132,10 @@ case class CatalogTablePartition( /** * Given the partition schema, returns a row with that schema holding the partition values. 
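(Editor's aside, not part of the patch: a hedged illustration of the table property this commit message describes, using the property name introduced later in this diff, `parquet.mr.int96.write.zone`; the table name and zone are made up, and an existing `SparkSession` named `spark` with Hive support is assumed.)

```scala
// Tag a Hive Parquet table with a storage timezone so its INT96 timestamps are
// adjusted between that zone and the session timezone on read and write.
spark.sql("""
  CREATE TABLE events (ts TIMESTAMP)
  STORED AS PARQUET
  TBLPROPERTIES ('parquet.mr.int96.write.zone' = 'UTC')
""")
```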
*/ - def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { + def toRow(partitionSchema: StructType, defaultTimeZoneId: String): InternalRow = { val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) val timeZoneId = caseInsensitiveProperties.getOrElse( - DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId) + DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId) InternalRow.fromSeq(partitionSchema.map { field => val partValue = if (spec(field.name) == ExternalCatalogUtils.DEFAULT_PARTITION_NAME) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 6c1592fd8881..bf596fa0a89d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -498,6 +498,11 @@ object DateTimeUtils { false } + lazy val validTimezones = TimeZone.getAvailableIDs().toSet + def isValidTimezone(timezoneId: String): Boolean = { + validTimezones.contains(timezoneId) + } + /** * Returns the microseconds since year zero (-17999) from microseconds since epoch. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 9d641b528723..dabbc2b6387e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.TimeZone; +import org.apache.hadoop.conf.Configuration; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -30,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -90,11 +93,30 @@ public class VectorizedColumnReader { private final PageReader pageReader; private final ColumnDescriptor descriptor; + private final TimeZone storageTz; + private final TimeZone sessionTz; - public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader, + Configuration conf) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. + // The conf is sometimes null in tests. + String sessionTzString = + conf == null ? null : conf.get(SQLConf.SESSION_LOCAL_TIMEZONE().key()); + if (sessionTzString == null || sessionTzString.isEmpty()) { + sessionTz = DateTimeUtils.defaultTimeZone(); + } else { + sessionTz = TimeZone.getTimeZone(sessionTzString); + } + String storageTzString = + conf == null ? 
null : conf.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY()); + if (storageTzString == null || storageTzString.isEmpty()) { + storageTz = sessionTz; + } else { + storageTz = TimeZone.getTimeZone(storageTzString); + } this.maxDefLevel = descriptor.getMaxDefinitionLevel(); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); @@ -289,7 +311,7 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, // TODO: Convert dictionary of Binaries to dictionary of Longs if (!column.isNullAt(i)) { Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); - column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v)); + column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v, sessionTz, storageTz)); } } } else { @@ -422,7 +444,7 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, // Read 12 bytes for INT96 - ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12))); + ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12), sessionTz, storageTz)); } else { column.putNull(rowId + i); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 51bdf0f0f229..d8974ddf2470 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -21,6 +21,7 @@ import java.util.Arrays; import java.util.List; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.column.ColumnDescriptor; @@ -95,6 +96,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private boolean returnColumnarBatch; + private Configuration conf; + /** * The default config on whether columnarBatch should be offheap. 
*/ @@ -107,6 +110,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException, InterruptedException, UnsupportedOperationException { super.initialize(inputSplit, taskAttemptContext); + this.conf = taskAttemptContext.getConfiguration(); initializeInternal(); } @@ -277,7 +281,7 @@ private void checkEndOfRowGroup() throws IOException { for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader(columns.get(i), - pages.getPageReader(columns.get(i))); + pages.getPageReader(columns.get(i)), conf); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index ebf03e1bf886..5843c5b56d44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -26,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import scala.util.Try -import org.apache.commons.lang3.StringEscapeUtils import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -37,7 +36,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -74,6 +73,10 @@ case class CreateTableLikeCommand( sourceTableDesc.provider } + val properties = sourceTableDesc.properties.filter { case (k, _) => + k == ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + } + // If the location is specified, we create an external table internally. // Otherwise create a managed table. 
val tblType = if (location.isEmpty) CatalogTableType.MANAGED else CatalogTableType.EXTERNAL @@ -86,6 +89,7 @@ case class CreateTableLikeCommand( locationUri = location.map(CatalogUtils.stringToURI(_))), schema = sourceTableDesc.schema, provider = newProvider, + properties = properties, partitionColumnNames = sourceTableDesc.partitionColumnNames, bucketSpec = sourceTableDesc.bucketSpec) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 2f3a2c62b912..8113768cd793 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -632,4 +632,6 @@ object ParquetFileFormat extends Logging { Failure(cause) }.toOption } + + val PARQUET_TIMEZONE_TABLE_PROPERTY = "parquet.mr.int96.write.zone" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index f1a35dd8a620..bf395a0bef74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -95,7 +95,8 @@ private[parquet] class ParquetReadSupport extends ReadSupport[UnsafeRow] with Lo new ParquetRecordMaterializer( parquetRequestedSchema, ParquetReadSupport.expandUDT(catalystRequestedSchema), - new ParquetSchemaConverter(conf)) + new ParquetSchemaConverter(conf), + conf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index 4e49a0dac97c..df041996cdea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import org.apache.hadoop.conf.Configuration import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType @@ -29,13 +30,17 @@ import org.apache.spark.sql.types.StructType * @param parquetSchema Parquet schema of the records to be read * @param catalystSchema Catalyst schema of the rows to be constructed * @param schemaConverter A Parquet-Catalyst schema converter that helps initializing row converters + * @param hadoopConf hadoop Configuration for passing extra params for parquet conversion */ private[parquet] class ParquetRecordMaterializer( - parquetSchema: MessageType, catalystSchema: StructType, schemaConverter: ParquetSchemaConverter) + parquetSchema: MessageType, + catalystSchema: StructType, + schemaConverter: ParquetSchemaConverter, + hadoopConf: Configuration) extends RecordMaterializer[UnsafeRow] { private val rootConverter = - new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, NoopUpdater) + new ParquetRowConverter(schemaConverter, parquetSchema, catalystSchema, hadoopConf, NoopUpdater) override def getCurrentRecord: UnsafeRow = rootConverter.currentRecord diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 32e6c60cd976..d52ff62d93b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder +import java.util.TimeZone import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.conf.Configuration import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type} @@ -34,6 +36,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -117,12 +120,14 @@ private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpd * @param parquetType Parquet schema of Parquet records * @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined * types should have been expanded. + * @param hadoopConf a hadoop Configuration for passing any extra parameters for parquet conversion * @param updater An updater which propagates converted field values to the parent container */ private[parquet] class ParquetRowConverter( schemaConverter: ParquetSchemaConverter, parquetType: GroupType, catalystType: StructType, + hadoopConf: Configuration, updater: ParentContainerUpdater) extends ParquetGroupConverter(updater) with Logging { @@ -261,18 +266,18 @@ private[parquet] class ParquetRowConverter( case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. 
+ val sessionTzString = hadoopConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key) + val sessionTz = Option(sessionTzString).map(TimeZone.getTimeZone(_)) + .getOrElse(DateTimeUtils.defaultTimeZone()) + val storageTzString = hadoopConf.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + val storageTz = Option(storageTzString).map(TimeZone.getTimeZone(_)).getOrElse(sessionTz) new ParquetPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { - assert( - value.length() == 12, - "Timestamps (with nanoseconds) are expected to be stored in 12-byte long binaries, " + - s"but got a ${value.length()}-byte binary.") - - val buf = value.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) - val timeOfDayNanos = buf.getLong - val julianDay = buf.getInt - updater.setLong(DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos)) + val timestamp = ParquetRowConverter.binaryToSQLTimestamp(value, sessionTz = sessionTz, + storageTz = storageTz) + updater.setLong(timestamp) } } @@ -302,7 +307,7 @@ private[parquet] class ParquetRowConverter( case t: StructType => new ParquetRowConverter( - schemaConverter, parquetType.asGroupType(), t, new ParentContainerUpdater { + schemaConverter, parquetType.asGroupType(), t, hadoopConf, new ParentContainerUpdater { override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) @@ -651,6 +656,7 @@ private[parquet] class ParquetRowConverter( } private[parquet] object ParquetRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without @@ -673,12 +679,35 @@ private[parquet] object ParquetRowConverter { unscaled } - def binaryToSQLTimestamp(binary: Binary): SQLTimestamp = { + /** + * Converts an int96 to a SQLTimestamp, given both the storage timezone and the local timezone. + * The timestamp is really meant to be interpreted as a "floating time", but since we + * actually store it as micros since epoch, which is why we have to apply a conversion when + * timezones change. + * + * @param binary a parquet Binary which holds one int96 + * @param sessionTz the session timezone. This will be used to determine how to display the time, + * and compute functions on the timestamp which involve a timezone, e.g. extract + * the hour. + * @param storageTz the timezone which was used to store the timestamp. This should come from the + * timestamp table property, or else assume it's the same as the sessionTz + * @return a timestamp (micros since epoch) which will render correctly in the sessionTz + */ + def binaryToSQLTimestamp( + binary: Binary, + sessionTz: TimeZone, + storageTz: TimeZone): SQLTimestamp = { assert(binary.length() == 12, s"Timestamps (with nanoseconds) are expected to be stored in" + s" 12-byte long binaries. Found a ${binary.length()}-byte binary instead.") val buffer = binary.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) val timeOfDayNanos = buffer.getLong val julianDay = buffer.getInt - DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + val utcEpochMicros = DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + // avoid expensive time logic if possible.
+ if (sessionTz.getID() != storageTz.getID()) { + DateTimeUtils.convertTz(utcEpochMicros, sessionTz, storageTz) + } else { + utcEpochMicros + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 38b0e33937f3..679ed8e361b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.{ByteBuffer, ByteOrder} import java.util +import java.util.TimeZone import scala.collection.JavaConverters.mapAsJavaMapConverter @@ -75,6 +76,9 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // Reusable byte array used to write decimal values private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) + private var storageTz: TimeZone = _ + private var sessionTz: TimeZone = _ + override def init(configuration: Configuration): WriteContext = { val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) this.schema = StructType.fromString(schemaString) @@ -91,6 +95,19 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) + // If the table has a timezone property, apply the correct conversions. See SPARK-12297. + val sessionTzString = configuration.get(SQLConf.SESSION_LOCAL_TIMEZONE.key) + sessionTz = if (sessionTzString == null || sessionTzString == "") { + TimeZone.getDefault() + } else { + TimeZone.getTimeZone(sessionTzString) + } + val storageTzString = configuration.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + storageTz = if (storageTzString == null || storageTzString == "") { + sessionTz + } else { + TimeZone.getTimeZone(storageTzString) + } val messageType = new ParquetSchemaConverter(configuration).convert(schema) val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava @@ -178,7 +195,13 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // NOTE: Starting from Spark 1.5, Spark SQL `TimestampType` only has microsecond // precision. Nanosecond parts of timestamp values read from INT96 are simply stripped. 
- val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) + val rawMicros = row.getLong(ordinal) + val adjustedMicros = if (sessionTz.getID() == storageTz.getID()) { + rawMicros + } else { + DateTimeUtils.convertTz(rawMicros, storageTz, sessionTz) + } + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(adjustedMicros) val buf = ByteBuffer.wrap(timestampBuffer) buf.order(ByteOrder.LITTLE_ENDIAN).putLong(timeOfDayNanos).putInt(julianDay) recordConsumer.addBinary(Binary.fromReusedByteArray(timestampBuffer)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index ba48facff293..8fef467f5f5c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -39,9 +39,10 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.internal.StaticSQLConf._ @@ -224,6 +225,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat throw new TableAlreadyExistsException(db = db, table = table) } + val tableTz = tableDefinition.properties.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) + tableTz.foreach { tz => + if (!DateTimeUtils.isValidTimezone(tz)) { + throw new AnalysisException(s"Cannot set" + + s" ${ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY} to invalid timezone $tz") + } + } + if (tableDefinition.tableType == VIEW) { client.createTable(tableDefinition, ignoreIfExists) } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b98066cb76c..e0b565c0d79a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ @@ -174,7 +175,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // We don't support hive bucketed tables, only ones we write out. 
bucketSpec = None, fileFormat = fileFormat, - options = options)(sparkSession = sparkSession) + options = options ++ getStorageTzOptions(relation))(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) tableRelationCache.put(tableIdentifier, created) created @@ -201,7 +202,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log userSpecifiedSchema = Option(dataSchema), // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, - options = options, + options = options ++ getStorageTzOptions(relation), className = fileType).resolveRelation(), table = updatedTable) @@ -222,6 +223,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log result.copy(output = newOutput) } + private def getStorageTzOptions(relation: CatalogRelation): Map[String, String] = { + // We add the table timezone to the relation options, which automatically gets injected into the + // hadoopConf for the Parquet Converters + val storageTzKey = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + relation.tableMeta.properties.get(storageTzKey).map(storageTzKey -> _).toMap + } + private def inferIfNeeded( relation: CatalogRelation, options: Map[String, String], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 05b6059472f5..2bfd63d9b56e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -17,12 +17,22 @@ package org.apache.spark.sql.hive +import java.io.File +import java.net.URLDecoder import java.sql.Timestamp +import java.util.TimeZone -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.parquet.hadoop.ParquetFileReader +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName + +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.datasources.parquet.{ParquetCompatibilityTest, ParquetFileFormat} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StringType, StructType, TimestampType} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { /** @@ -141,4 +151,369 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi Row(Seq(Row(1))), "ARRAY>") } + + val testTimezones = Seq( + "UTC" -> "UTC", + "LA" -> "America/Los_Angeles", + "Berlin" -> "Europe/Berlin" + ) + // Check creating parquet tables with timestamps, writing data into them, and reading it back out + // under a variety of conditions: + // * tables with explicit tz and those without + // * altering table properties directly + // * variety of timezones, local & non-local + val sessionTimezones = testTimezones.map(_._2).map(Some(_)) ++ Seq(None) + sessionTimezones.foreach { sessionTzOpt => + val sparkSession = spark.newSession() + sessionTzOpt.foreach { tz => sparkSession.conf.set(SQLConf.SESSION_LOCAL_TIMEZONE.key, tz) } + testCreateWriteRead(sparkSession, "no_tz", None, sessionTzOpt) + val localTz = TimeZone.getDefault.getID() + 
testCreateWriteRead(sparkSession, "local", Some(localTz), sessionTzOpt) + // check with a variety of timezones. The unit tests currently are configured to always use + // America/Los_Angeles, but even if they didn't, we'd be sure to cover a non-local timezone. + testTimezones.foreach { case (tableName, zone) => + if (zone != localTz) { + testCreateWriteRead(sparkSession, tableName, Some(zone), sessionTzOpt) + } + } + } + + private def testCreateWriteRead( + sparkSession: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + testCreateAlterTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + testWriteTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + testReadTablesWithTimezone(sparkSession, baseTable, explicitTz, sessionTzOpt) + } + + private def checkHasTz(spark: SparkSession, table: String, tz: Option[String]): Unit = { + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier(table)) + assert(tableMetadata.properties.get(ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY) === tz) + } + + private def testCreateAlterTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + test(s"SPARK-12297: Create and Alter Parquet tables and timezones; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + withTable(baseTable, s"like_$baseTable", s"select_$baseTable", s"partitioned_$baseTable") { + // If we ever add a property to set the table timezone by default, defaultTz would change + val defaultTz = None + // check that created tables have correct TBLPROPERTIES + val tblProperties = explicitTz.map { + tz => s"""TBLPROPERTIES ($key="$tz")""" + }.getOrElse("") + spark.sql( + s"""CREATE TABLE $baseTable ( + | x int + | ) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + val expectedTableTz = explicitTz.orElse(defaultTz) + checkHasTz(spark, baseTable, expectedTableTz) + spark.sql( + s"""CREATE TABLE partitioned_$baseTable ( + | x int + | ) + | PARTITIONED BY (y int) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + checkHasTz(spark, s"partitioned_$baseTable", expectedTableTz) + spark.sql(s"CREATE TABLE like_$baseTable LIKE $baseTable") + checkHasTz(spark, s"like_$baseTable", expectedTableTz) + spark.sql( + s"""CREATE TABLE select_$baseTable + | STORED AS PARQUET + | AS + | SELECT * from $baseTable + """.stripMargin) + checkHasTz(spark, s"select_$baseTable", defaultTz) + + // check alter table, setting, unsetting, resetting the property + spark.sql( + s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="America/Los_Angeles")""") + checkHasTz(spark, baseTable, Some("America/Los_Angeles")) + spark.sql(s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="UTC")""") + checkHasTz(spark, baseTable, Some("UTC")) + spark.sql(s"""ALTER TABLE $baseTable UNSET TBLPROPERTIES ($key)""") + checkHasTz(spark, baseTable, None) + explicitTz.foreach { tz => + spark.sql(s"""ALTER TABLE $baseTable SET TBLPROPERTIES ($key="$tz")""") + checkHasTz(spark, baseTable, expectedTableTz) + } + } + } + } + + val desiredTimestampStrings = Seq( + "2015-12-31 22:49:59.123", + "2015-12-31 23:50:59.123", + "2016-01-01 00:39:59.123", + "2016-01-01 01:29:59.123" + ) + // We don't want to mess with timezones inside the tests themselves, since we use a shared + // spark context, and then we might be prone to issues from lazy vals for timezones. 
Instead, + // we manually adjust the timezone just to determine what the desired millis (since epoch, in utc) + // is for various "wall-clock" times in different timezones, and then we can compare against those + // in our tests. + val timestampTimezoneToMillis = { + val originalTz = TimeZone.getDefault + try { + desiredTimestampStrings.flatMap { timestampString => + Seq("America/Los_Angeles", "Europe/Berlin", "UTC").map { tzId => + TimeZone.setDefault(TimeZone.getTimeZone(tzId)) + val timestamp = Timestamp.valueOf(timestampString) + (timestampString, tzId) -> timestamp.getTime() + } + }.toMap + } finally { + TimeZone.setDefault(originalTz) + } + } + + private def createRawData(spark: SparkSession): Dataset[(String, Timestamp)] = { + import spark.implicits._ + val df = desiredTimestampStrings.toDF("display") + // this will get the millis corresponding to the display time given the current *session* + // timezone. + df.withColumn("ts", expr("cast(display as timestamp)")).as[(String, Timestamp)] + } + + private def testWriteTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]) : Unit = { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + test(s"SPARK-12297: Write to Parquet tables with Timestamps; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + + withTable(s"saveAsTable_$baseTable", s"insert_$baseTable", s"partitioned_ts_$baseTable") { + val sessionTzId = sessionTzOpt.getOrElse(TimeZone.getDefault().getID()) + // check that created tables have correct TBLPROPERTIES + val tblProperties = explicitTz.map { + tz => s"""TBLPROPERTIES ($key="$tz")""" + }.getOrElse("") + + val rawData = createRawData(spark) + // Check writing data out. + // We write data into our tables, and then check the raw parquet files to see whether + // the correct conversion was applied. + rawData.write.saveAsTable(s"saveAsTable_$baseTable") + checkHasTz(spark, s"saveAsTable_$baseTable", None) + spark.sql( + s"""CREATE TABLE insert_$baseTable ( + | display string, + | ts timestamp + | ) + | STORED AS PARQUET + | $tblProperties + """.stripMargin) + checkHasTz(spark, s"insert_$baseTable", explicitTz) + rawData.write.insertInto(s"insert_$baseTable") + // no matter what, roundtripping via the table should leave the data unchanged + val readFromTable = spark.table(s"insert_$baseTable").collect() + .map { row => (row.getAs[String](0), row.getAs[Timestamp](1)).toString() }.sorted + assert(readFromTable === rawData.collect().map(_.toString()).sorted) + + // Now we load the raw parquet data on disk, and check if it was adjusted correctly. + // Note that we only store the timezone in the table property, so when we read the + // data this way, we're bypassing all of the conversion logic, and reading the raw + // values in the parquet file. + val onDiskLocation = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(s"insert_$baseTable")).location.getPath + // we test reading the data back with and without the vectorized reader, to make sure we + // haven't broken reading parquet from non-hive tables, with both readers. 
+ Seq(false, true).foreach { vectorized => + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized) + val readFromDisk = spark.read.parquet(onDiskLocation).collect() + val storageTzId = explicitTz.getOrElse(sessionTzId) + readFromDisk.foreach { row => + val displayTime = row.getAs[String](0) + val millis = row.getAs[Timestamp](1).getTime() + val expectedMillis = timestampTimezoneToMillis((displayTime, storageTzId)) + assert(expectedMillis === millis, s"Display time '$displayTime' was stored " + + s"incorrectly with sessionTz = ${sessionTzOpt}; Got $millis, expected " + + s"$expectedMillis (delta = ${millis - expectedMillis})") + } + } + + // check tables partitioned by timestamps. We don't compare the "raw" data in this case, + // since they are adjusted even when we bypass the hive table. + rawData.write.partitionBy("ts").saveAsTable(s"partitioned_ts_$baseTable") + val partitionDiskLocation = spark.sessionState.catalog + .getTableMetadata(TableIdentifier(s"partitioned_ts_$baseTable")).location.getPath + // no matter what mix of timezones we use, the dirs should specify the value with the + // same time we use for display. + val parts = new File(partitionDiskLocation).list().collect { + case name if name.startsWith("ts=") => URLDecoder.decode(name.stripPrefix("ts=")) + }.toSet + assert(parts === desiredTimestampStrings.toSet) + } + } + } + + private def testReadTablesWithTimezone( + spark: SparkSession, + baseTable: String, + explicitTz: Option[String], + sessionTzOpt: Option[String]): Unit = { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + test(s"SPARK-12297: Read from Parquet tables with Timestamps; explicitTz = $explicitTz; " + + s"sessionTzOpt = $sessionTzOpt") { + withTable(s"external_$baseTable", s"partitioned_$baseTable") { + // we intentionally save this data directly, without creating a table, so we can + // see that the data is read back differently depending on table properties. + // we'll save with adjusted millis, so that it should be the correct millis after reading + // back. + val rawData = createRawData(spark) + // to avoid closing over entire class + val timestampTimezoneToMillis = this.timestampTimezoneToMillis + import spark.implicits._ + val adjustedRawData = (explicitTz match { + case Some(tzId) => + rawData.map { case (displayTime, _) => + val storageMillis = timestampTimezoneToMillis((displayTime, tzId)) + (displayTime, new Timestamp(storageMillis)) + } + case _ => + rawData + }).withColumnRenamed("_1", "display").withColumnRenamed("_2", "ts") + withTempPath { basePath => + val unpartitionedPath = new File(basePath, "flat") + val partitionedPath = new File(basePath, "partitioned") + adjustedRawData.write.parquet(unpartitionedPath.getCanonicalPath) + val options = Map("path" -> unpartitionedPath.getCanonicalPath) ++ + explicitTz.map { tz => Map(key -> tz) }.getOrElse(Map()) + + spark.catalog.createTable( + tableName = s"external_$baseTable", + source = "parquet", + schema = new StructType().add("display", StringType).add("ts", TimestampType), + options = options + ) + + // also write out a partitioned table, to make sure we can access that correctly. + // add a column we can partition by (value doesn't particularly matter). + val partitionedData = adjustedRawData.withColumn("id", monotonicallyIncreasingId) + partitionedData.write.partitionBy("id") + .parquet(partitionedPath.getCanonicalPath) + // unfortunately, catalog.createTable() doesn't let us specify partitioning, so just use + // a "CREATE TABLE" stmt. 
+ val tblOpts = explicitTz.map { tz => s"""TBLPROPERTIES ($key="$tz")""" }.getOrElse("") + spark.sql(s"""CREATE EXTERNAL TABLE partitioned_$baseTable ( + | display string, + | ts timestamp + |) + |PARTITIONED BY (id bigint) + |STORED AS parquet + |LOCATION 'file:${partitionedPath.getCanonicalPath}' + |$tblOpts + """.stripMargin) + spark.sql(s"msck repair table partitioned_$baseTable") + + for { + vectorized <- Seq(false, true) + partitioned <- Seq(false, true) + } { + withClue(s"vectorized = $vectorized; partitioned = $partitioned") { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, vectorized) + val sessionTz = sessionTzOpt.getOrElse(TimeZone.getDefault().getID()) + val table = if (partitioned) s"partitioned_$baseTable" else s"external_$baseTable" + val query = s"select display, cast(ts as string) as ts_as_string, ts " + + s"from $table" + val collectedFromExternal = spark.sql(query).collect() + assert( collectedFromExternal.size === 4) + collectedFromExternal.foreach { row => + val displayTime = row.getAs[String](0) + // the timestamp should still display the same, despite the changes in timezones + assert(displayTime === row.getAs[String](1).toString()) + // we'll also check that the millis behind the timestamp has the appropriate + // adjustments. + val millis = row.getAs[Timestamp](2).getTime() + val expectedMillis = timestampTimezoneToMillis((displayTime, sessionTz)) + val delta = millis - expectedMillis + val deltaHours = delta / (1000L * 60 * 60) + assert(millis === expectedMillis, s"Display time '$displayTime' did not have " + + s"correct millis: was $millis, expected $expectedMillis; delta = $delta " + + s"($deltaHours hours)") + } + + // Now test that the behavior is still correct even with a filter which could get + // pushed down into parquet. We don't need extra handling for pushed down + // predicates because (a) in ParquetFilters, we ignore TimestampType and (b) parquet + // does not read statistics from int96 fields, as they are unsigned. See + // scalastyle:off line.size.limit + // https://github.com/apache/parquet-mr/blob/2fd62ee4d524c270764e9b91dca72e5cf1a005b7/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java#L419 + // https://github.com/apache/parquet-mr/blob/2fd62ee4d524c270764e9b91dca72e5cf1a005b7/parquet-hadoop/src/main/java/org/apache/parquet/format/converter/ParquetMetadataConverter.java#L348 + // scalastyle:on line.size.limit + // + // Just to be defensive in case anything ever changes in parquet, this test checks + // the assumption on column stats, and also the end-to-end behavior. + + val hadoopConf = sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val parts = if (partitioned) { + val subdirs = fs.listStatus(new Path(partitionedPath.getCanonicalPath)) + .filter(_.getPath().getName().startsWith("id=")) + fs.listStatus(subdirs.head.getPath()) + .filter(_.getPath().getName().endsWith(".parquet")) + } else { + fs.listStatus(new Path(unpartitionedPath.getCanonicalPath)) + .filter(_.getPath().getName().endsWith(".parquet")) + } + // grab the meta data from the parquet file. The next section of asserts just make + // sure the test is configured correctly. 
+ assert(parts.size == 1) + val oneFooter = ParquetFileReader.readFooter(hadoopConf, parts.head.getPath) + assert(oneFooter.getFileMetaData.getSchema.getColumns.size === 2) + assert(oneFooter.getFileMetaData.getSchema.getColumns.get(1).getType() === + PrimitiveTypeName.INT96) + val oneBlockMeta = oneFooter.getBlocks().get(0) + val oneBlockColumnMeta = oneBlockMeta.getColumns().get(1) + val columnStats = oneBlockColumnMeta.getStatistics + // This is the important assert. Column stats are written, but they are ignored + // when the data is read back as mentioned above, b/c int96 is unsigned. This + // assert makes sure this holds even if we change parquet versions (if eg. there + // were ever statistics even on unsigned columns). + assert(columnStats.isEmpty) + + // These queries should return the entire dataset, but if the predicates were + // applied to the raw values in parquet, they would incorrectly filter data out. + Seq( + ">" -> "2015-12-31 22:00:00", + "<" -> "2016-01-01 02:00:00" + ).foreach { case (comparison, value) => + val query = + s"select ts from $table where ts $comparison '$value'" + val countWithFilter = spark.sql(query).count() + assert(countWithFilter === 4, query) + } + } + } + } + } + } + } + + test("SPARK-12297: exception on bad timezone") { + val key = ParquetFileFormat.PARQUET_TIMEZONE_TABLE_PROPERTY + val badTzException = intercept[AnalysisException] { + spark.sql( + s"""CREATE TABLE bad_tz_table ( + | x int + | ) + | STORED AS PARQUET + | TBLPROPERTIES ($key="Blart Versenwald III") + """.stripMargin) + } + assert(badTzException.getMessage.contains("Blart Versenwald III")) + } } From c24bdaab5a234d18b273544cefc44cc4005bf8fc Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 7 May 2017 23:10:18 -0700 Subject: [PATCH 446/512] [SPARK-20626][SPARKR] address date test warning with timezone on windows ## What changes were proposed in this pull request? set timezone on windows ## How was this patch tested? unit test, AppVeyor Author: Felix Cheung Closes #17892 from felixcheung/rtimestamptest. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 0856bab5686c..f517ce671313 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -96,6 +96,10 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} + test_that("calling sparkRSQL.init returns existing SQL context", { skip_on_cran() From 42cc6d13edbebb7c435ec47c0c12b445e05fdd49 Mon Sep 17 00:00:00 2001 From: sujith71955 Date: Sun, 7 May 2017 23:15:00 -0700 Subject: [PATCH 447/512] [SPARK-20380][SQL] Unable to set/unset table comment property using ALTER TABLE SET/UNSET TBLPROPERTIES ddl ### What changes were proposed in this pull request? 
The table comment was not being set or unset by an **ALTER TABLE SET/UNSET TBLPROPERTIES** query, e.g. ALTER TABLE table_with_comment SET TBLPROPERTIES("comment" = "modified comment"). When a user altered the table properties to add or update the table comment, the comment field of the **CatalogTable** instance was not updated, so the old comment (if any) was still shown to the user. To handle this issue, this change updates the comment field of the **CatalogTable** with the newly added/modified comment, along with the other table-level properties, when the user executes an **ALTER TABLE SET TBLPROPERTIES** query. This PR also takes care of unsetting or removing the table comment when the user executes an **ALTER TABLE UNSET TBLPROPERTIES** query, e.g. ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment') ### How was this patch tested? Added test cases to **SQLQueryTestSuite** that verify the table comment via a DESC FORMATTED query after adding/modifying the comment through **AlterTableSetPropertiesCommand** and after unsetting it through **AlterTableUnsetPropertiesCommand**. Author: sujith71955 Closes #17649 from sujith71955/alter_table_comment. --- .../catalyst/catalog/InMemoryCatalog.scala | 8 +- .../spark/sql/execution/command/ddl.scala | 12 +- .../describe-table-after-alter-table.sql | 29 ++++ .../describe-table-after-alter-table.sql.out | 161 ++++++++++++++++++ 4 files changed, 204 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 81dd8efc0015..8a5319bebe54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -216,8 +216,8 @@ class InMemoryCatalog( } else { tableDefinition } - - catalog(db).tables.put(table, new TableDesc(tableWithLocation)) + val tableProp = tableWithLocation.properties.filter(_._1 != "comment") + catalog(db).tables.put(table, new TableDesc(tableWithLocation.copy(properties = tableProp))) } } @@ -298,7 +298,9 @@ class InMemoryCatalog( assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) - catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition + val updatedProperties = tableDefinition.properties.filter(kv => kv._1 != "comment") + val newTableDefinition = tableDefinition.copy(properties = updatedProperties) + catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } override def alterTableSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 55540563ef91..793fb9b79559 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -231,8 +231,12 @@ case class AlterTableSetPropertiesCommand( val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) DDLUtils.verifyAlterTableType(catalog, table, isView) -
// This overrides old properties - val newTable = table.copy(properties = table.properties ++ properties) + // This overrides old properties and update the comment parameter of CatalogTable + // with the newly added/modified comment since CatalogTable also holds comment as its + // direct property. + val newTable = table.copy( + properties = table.properties ++ properties, + comment = properties.get("comment")) catalog.alterTable(newTable) Seq.empty[Row] } @@ -267,8 +271,10 @@ case class AlterTableUnsetPropertiesCommand( } } } + // If comment is in the table property, we reset it to None + val tableComment = if (propKeys.contains("comment")) None else table.properties.get("comment") val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) } - val newTable = table.copy(properties = newProperties) + val newTable = table.copy(properties = newProperties, comment = tableComment) catalog.alterTable(newTable) Seq.empty[Row] } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql new file mode 100644 index 000000000000..69bff6656c43 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-table-after-alter-table.sql @@ -0,0 +1,29 @@ +CREATE TABLE table_with_comment (a STRING, b INT, c STRING, d STRING) USING parquet COMMENT 'added'; + +DESC FORMATTED table_with_comment; + +-- ALTER TABLE BY MODIFYING COMMENT +ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment", "type"= "parquet"); + +DESC FORMATTED table_with_comment; + +-- DROP TEST TABLE +DROP TABLE table_with_comment; + +-- CREATE TABLE WITHOUT COMMENT +CREATE TABLE table_comment (a STRING, b INT) USING parquet; + +DESC FORMATTED table_comment; + +-- ALTER TABLE BY ADDING COMMENT +ALTER TABLE table_comment SET TBLPROPERTIES(comment = "added comment"); + +DESC formatted table_comment; + +-- ALTER UNSET PROPERTIES COMMENT +ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment'); + +DESC FORMATTED table_comment; + +-- DROP TEST TABLE +DROP TABLE table_comment; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out new file mode 100644 index 000000000000..1cc11c475bc4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe-table-after-alter-table.sql.out @@ -0,0 +1,161 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +CREATE TABLE table_with_comment (a STRING, b INT, c STRING, d STRING) USING parquet COMMENT 'added' +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +DESC FORMATTED table_with_comment +-- !query 1 schema +struct +-- !query 1 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table table_with_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment added +Location [not included in comparison]sql/core/spark-warehouse/table_with_comment + + +-- !query 2 +ALTER TABLE table_with_comment SET TBLPROPERTIES("comment"= "modified comment", "type"= "parquet") +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +DESC FORMATTED table_with_comment +-- !query 3 schema +struct +-- !query 3 output +# col_name data_type comment +a string +b int +c string +d string + +# 
Detailed Table Information +Database default +Table table_with_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment modified comment +Properties [type=parquet] +Location [not included in comparison]sql/core/spark-warehouse/table_with_comment + + +-- !query 4 +DROP TABLE table_with_comment +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE table_comment (a STRING, b INT) USING parquet +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +DESC FORMATTED table_comment +-- !query 6 schema +struct +-- !query 6 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 7 +ALTER TABLE table_comment SET TBLPROPERTIES(comment = "added comment") +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +DESC formatted table_comment +-- !query 8 schema +struct +-- !query 8 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Comment added comment +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 9 +ALTER TABLE table_comment UNSET TBLPROPERTIES IF EXISTS ('comment') +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +DESC FORMATTED table_comment +-- !query 10 schema +struct +-- !query 10 output +# col_name data_type comment +a string +b int + +# Detailed Table Information +Database default +Table table_comment +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Location [not included in comparison]sql/core/spark-warehouse/table_comment + + +-- !query 11 +DROP TABLE table_comment +-- !query 11 schema +struct<> +-- !query 11 output + From 2fdaeb52bbe2ed1a9127ac72917286e505303c85 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Sun, 7 May 2017 23:16:30 -0700 Subject: [PATCH 448/512] [SPARKR][DOC] fix typo in vignettes ## What changes were proposed in this pull request? Fix typo in vignettes Author: Wayne Zhang Closes #17884 from actuaryzhang/typo. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 36 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d38ec4f1b6f3..49f4ab8f146a 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -65,7 +65,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun head(carsDF) ``` -Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`. ```{r} carsSubDF <- select(carsDF, "model", "mpg", "hp") carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) @@ -379,7 +379,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` -Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. 
The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} out <- dapplyCollect( @@ -405,7 +405,7 @@ result <- gapply( head(arrange(result, "max_mpg", decreasing = TRUE)) ``` -Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} result <- gapplyCollect( @@ -458,20 +458,20 @@ options(ops) ### SQL Queries -A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. ```{r} people <- read.df(paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/people.json"), "json") ``` -Register this SparkDataFrame as a temporary view. +Register this `SparkDataFrame` as a temporary view. ```{r} createOrReplaceTempView(people, "people") ``` -SQL statements can be run by using the sql method. +SQL statements can be run using the sql method. ```{r} teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) @@ -780,7 +780,7 @@ head(predict(isoregModel, newDF)) `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions: +We use the `longley` dataset to train a gradient-boosted tree and make predictions: ```{r, warning=FALSE} df <- createDataFrame(longley) @@ -820,7 +820,7 @@ head(select(fitted, "Class", "prediction")) `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. -We use a simulated example to demostrate the usage. 
+We use a simulated example to demonstrate the usage. ```{r} X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) @@ -851,9 +851,9 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 * Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). -* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. +* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. -To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two options for the column: * character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. @@ -901,7 +901,7 @@ perplexity `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). -There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. +There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. ```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -981,7 +981,7 @@ testSummary ### Model Persistence -The following example shows how to save/load an ML model by SparkR. +The following example shows how to save/load an ML model in SparkR. ```{r} t <- as.data.frame(Titanic) training <- createDataFrame(t) @@ -1079,19 +1079,19 @@ There are three main object classes in SparkR you may be working with. + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + `env` saves the meta-information of the object such as `isCached`. -It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. -* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. +* `Column`: an S4 class representing a column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend. -It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + It can be obtained from a `SparkDataFrame` by `$` operator, e.g., `df$col`. 
More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. -* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. -This is often an intermediate object with group information and followed up by aggregation operations. + This is often an intermediate object with group information and followed up by aggregation operations. ### Architecture -A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. +A complete description of architecture can be seen in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. From 0f820e2b6c507dc4156703862ce65e598ca41cca Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 8 May 2017 10:00:58 +0100 Subject: [PATCH 449/512] [SPARK-20519][SQL][CORE] Modify to prevent some possible runtime exceptions Signed-off-by: liuxian ## What changes were proposed in this pull request? When the input parameter is null, a runtime exception may occur. ## How was this patch tested? Existing unit tests Author: liuxian Closes #17796 from 10110346/wip_lx_0428.
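For illustration only (not part of the original patch): the hardened checks boil down to asserting non-null before inspecting the string, so a null host now fails with a descriptive `AssertionError` instead of a bare `NullPointerException` from `host.indexOf(':')`. A minimal, self-contained Scala sketch of that pattern follows; the `HostChecks` object name and the `checkHostPort` message text are hypothetical, while the `checkHost` assertion mirrors the one added by this patch in `Utils.scala`.

```scala
// Minimal sketch (assumed, standalone) of the null-safe host checks this patch introduces.
object HostChecks {

  // A bare hostname must be non-null and must not carry a ":port" suffix.
  def checkHost(host: String): Unit = {
    assert(host != null && host.indexOf(':') == -1, s"Expected hostname (not IP) but got $host")
  }

  // A "host:port" string must be non-null and must contain the ':' separator.
  def checkHostPort(hostPort: String): Unit = {
    assert(hostPort != null && hostPort.indexOf(':') != -1, s"Expected host:port but got $hostPort")
  }

  def main(args: Array[String]): Unit = {
    checkHost("spark-master")           // passes
    checkHostPort("spark-master:7077")  // passes
    // checkHost(null)                  // fails with a clear AssertionError rather than an NPE
    // checkHost("spark-master:7077")   // fails: a bare hostname must not include a port
  }
}
```
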
--- .../scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- .../scala/org/apache/spark/deploy/DeployMessage.scala | 8 ++++---- .../scala/org/apache/spark/deploy/master/Master.scala | 2 +- .../org/apache/spark/deploy/master/MasterArguments.scala | 4 ++-- .../org/apache/spark/deploy/master/WorkerInfo.scala | 2 +- .../scala/org/apache/spark/deploy/worker/Worker.scala | 2 +- .../org/apache/spark/deploy/worker/WorkerArguments.scala | 4 ++-- .../main/scala/org/apache/spark/executor/Executor.scala | 2 +- .../scala/org/apache/spark/storage/BlockManagerId.scala | 2 +- core/src/main/scala/org/apache/spark/util/RpcUtils.scala | 2 +- core/src/main/scala/org/apache/spark/util/Utils.scala | 9 +++++---- .../deploy/mesos/MesosClusterDispatcherArguments.scala | 2 +- 12 files changed, 21 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b0dd2fc187ba..fb0405b1a69c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -879,7 +879,7 @@ private[spark] class PythonAccumulatorV2( private val serverPort: Int) extends CollectionAccumulator[Array[Byte]] { - Utils.checkHost(serverHost, "Expected hostname") + Utils.checkHost(serverHost) val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index ac09c6c497f8..b5cb3f0a0f9d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -43,7 +43,7 @@ private[deploy] object DeployMessages { memory: Int, workerWebUiUrl: String) extends DeployMessage { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } @@ -131,7 +131,7 @@ private[deploy] object DeployMessages { // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { - Utils.checkHostPort(hostPort, "Required hostport") + Utils.checkHostPort(hostPort) } case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -183,7 +183,7 @@ private[deploy] object DeployMessages { completedDrivers: Array[DriverInfo], status: MasterState) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) def uri: String = "spark://" + host + ":" + port @@ -201,7 +201,7 @@ private[deploy] object DeployMessages { drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 816bf37e39fe..e061939623cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -80,7 +80,7 @@ private[deploy] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(address.host, "Expected hostname") + Utils.checkHost(address.host) private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private 
val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index c63793c16dce..615d2533cf08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -60,12 +60,12 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) exte @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 4e20c10fd142..c87d6e24b78c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -32,7 +32,7 @@ private[spark] class WorkerInfo( val webUiAddress: String) extends Serializable { - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 00b9d1af373d..34e3a4c020c8 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -55,7 +55,7 @@ private[deploy] class Worker( private val host = rpcEnv.address.host private val port = rpcEnv.address.port - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) // A scheduled executor used to send messages at the specified time. 
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 777020d4d5c8..bd07d342e04a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -68,12 +68,12 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 51b6c373c4da..3bc47b670305 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -71,7 +71,7 @@ private[spark] class Executor( private val conf = env.conf // No ip or host:port - just hostname - Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname) // must not have port specified. assert (0 == Utils.parseHostPort(executorHostname)._2) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index c37a3604d28f..2c3da0ee85e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -46,7 +46,7 @@ class BlockManagerId private ( def executorId: String = executorId_ if (null != host_) { - Utils.checkHost(host_, "Expected hostname") + Utils.checkHost(host_) assert (port_ > 0) } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 46a5cb2cff5a..e5cccf39f945 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -28,7 +28,7 @@ private[spark] object RpcUtils { def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") + Utils.checkHost(driverHost) rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4d37db96dfc3..edfe22979232 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -937,12 +937,13 @@ private[spark] object Utils extends Logging { customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } - def checkHost(host: String, message: String = "") { - assert(host.indexOf(':') == -1, message) + def checkHost(host: String) { + assert(host != null && host.indexOf(':') == -1, s"Expected hostname (not IP) but got $host") } - def checkHostPort(hostPort: String, message: String = "") { - assert(hostPort.indexOf(':') != -1, message) + def checkHostPort(hostPort: String) { + assert(hostPort != null && 
hostPort.indexOf(':') != -1, + s"Expected host and port but got $hostPort") } // Typically, this will be of order of number of nodes in cluster diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index ef08502ec8dd..ddea762fdb91 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -59,7 +59,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: @tailrec private def parse(args: List[String]): Unit = args match { case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) From 15526653a93a32cde3c9ea0c0e68e35622b0a590 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 8 May 2017 17:33:47 +0800 Subject: [PATCH 450/512] [SPARK-19956][CORE] Optimize a location order of blocks with topology information ## What changes were proposed in this pull request? When calling BlockManager's getLocations method, we only compare the hosts of the data blocks, and non-local blocks are chosen at random, so the selected block may live in a different rack. This patch adds rack-aware ordering of the candidate locations. ## How was this patch tested? New test case. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xianyang Liu Closes #17300 from ConeyLiu/blockmanager. --- .../apache/spark/storage/BlockManager.scala | 11 +++++-- .../spark/storage/BlockManagerSuite.scala | 31 +++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3219969bcd06..33ce30c58e1a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -612,12 +612,19 @@ private[spark] class BlockManager( /** * Return a list of locations for the given block, prioritizing the local machine since - * multiple block managers can share the same host. + * multiple block managers can share the same host, followed by hosts on the same rack.
*/ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { val locs = Random.shuffle(master.getLocations(blockId)) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } - preferredLocs ++ otherLocs + blockManagerId.topologyInfo match { + case None => preferredLocs ++ otherLocs + case Some(_) => + val (sameRackLocs, differentRackLocs) = otherLocs.partition { + loc => blockManagerId.topologyInfo == loc.topologyInfo + } + preferredLocs ++ sameRackLocs ++ differentRackLocs + } } /** diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index a8b960489983..1e7bcdb6740f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -496,8 +496,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } - test("optimize a location order of blocks") { - val localHost = Utils.localHostName() + test("optimize a location order of blocks without topology information") { + val localHost = "localhost" val otherHost = "otherHost" val bmMaster = mock(classOf[BlockManagerMaster]) val bmId1 = BlockManagerId("id1", localHost, 1) @@ -508,7 +508,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) - assert(locations.map(_.host).toSet === Set(localHost, localHost, otherHost)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) + } + + test("optimize a location order of blocks with topology information") { + val localHost = "localhost" + val otherHost = "otherHost" + val localRack = "localRack" + val otherRack = "otherRack" + + val bmMaster = mock(classOf[BlockManagerMaster]) + val bmId1 = BlockManagerId("id1", localHost, 1, Some(localRack)) + val bmId2 = BlockManagerId("id2", localHost, 2, Some(localRack)) + val bmId3 = BlockManagerId("id3", otherHost, 3, Some(otherRack)) + val bmId4 = BlockManagerId("id4", otherHost, 4, Some(otherRack)) + val bmId5 = BlockManagerId("id5", otherHost, 5, Some(localRack)) + when(bmMaster.getLocations(mc.any[BlockId])) + .thenReturn(Seq(bmId1, bmId2, bmId5, bmId3, bmId4)) + + val blockManager = makeBlockManager(128, "exec", bmMaster) + blockManager.blockManagerId = + BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) + val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) + val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) + assert(locations.flatMap(_.topologyInfo) + === Seq(localRack, localRack, localRack, otherRack, otherRack)) } test("SPARK-9591: getRemoteBytes from another location when Exception throw") { From 58518d070777fc0665c4d02bad8adf910807df98 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Mon, 8 May 2017 12:45:00 +0200 Subject: [PATCH 451/512] [SPARK-20596][ML][TEST] Consolidate and improve ALS recommendAll test cases Existing test cases for `recommendForAllX` methods (added in [SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)) test `k < num items` and `k = num items`. 
Technically we should also test that `k > num items` returns the same results as `k = num items`. ## How was this patch tested? Updated existing unit tests. Author: Nick Pentreath Closes #17860 from MLnick/SPARK-20596-als-rec-tests. --- .../spark/ml/recommendation/ALSSuite.scala | 63 ++++++++----------- 1 file changed, 25 insertions(+), 38 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 7574af3d77ea..9d31e792633c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -671,58 +671,45 @@ class ALSSuite .setItemCol("item") } - test("recommendForAllUsers with k < num_items") { - val topItems = getALSModel.recommendForAllUsers(2) - assert(topItems.count() == 3) - assert(topItems.columns.contains("user")) - - val expected = Map( - 0 -> Array((3, 54f), (4, 44f)), - 1 -> Array((3, 39f), (5, 33f)), - 2 -> Array((3, 51f), (5, 45f)) - ) - checkRecommendations(topItems, expected, "item") - } - - test("recommendForAllUsers with k = num_items") { - val topItems = getALSModel.recommendForAllUsers(4) - assert(topItems.count() == 3) - assert(topItems.columns.contains("user")) - + test("recommendForAllUsers with k <, = and > num_items") { + val model = getALSModel + val numUsers = model.userFactors.count + val numItems = model.itemFactors.count val expected = Map( 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) ) - checkRecommendations(topItems, expected, "item") - } - test("recommendForAllItems with k < num_users") { - val topUsers = getALSModel.recommendForAllItems(2) - assert(topUsers.count() == 4) - assert(topUsers.columns.contains("item")) - - val expected = Map( - 3 -> Array((0, 54f), (2, 51f)), - 4 -> Array((0, 44f), (2, 30f)), - 5 -> Array((2, 45f), (0, 42f)), - 6 -> Array((0, 28f), (2, 18f)) - ) - checkRecommendations(topUsers, expected, "user") + Seq(2, 4, 6).foreach { k => + val n = math.min(k, numItems).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topItems = model.recommendForAllUsers(k) + assert(topItems.count() == numUsers) + assert(topItems.columns.contains("user")) + checkRecommendations(topItems, expectedUpToN, "item") + } } - test("recommendForAllItems with k = num_users") { - val topUsers = getALSModel.recommendForAllItems(3) - assert(topUsers.count() == 4) - assert(topUsers.columns.contains("item")) - + test("recommendForAllItems with k <, = and > num_users") { + val model = getALSModel + val numUsers = model.userFactors.count + val numItems = model.itemFactors.count val expected = Map( 3 -> Array((0, 54f), (2, 51f), (1, 39f)), 4 -> Array((0, 44f), (2, 30f), (1, 26f)), 5 -> Array((2, 45f), (0, 42f), (1, 33f)), 6 -> Array((0, 28f), (2, 18f), (1, 16f)) ) - checkRecommendations(topUsers, expected, "user") + + Seq(2, 3, 4).foreach { k => + val n = math.min(k, numUsers).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topUsers = getALSModel.recommendForAllItems(k) + assert(topUsers.count() == numItems) + assert(topUsers.columns.contains("item")) + checkRecommendations(topUsers, expectedUpToN, "user") + } } private def checkRecommendations( From aeb2ecc0cd898f5352df0a04be1014b02ea3e20e Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Mon, 8 May 2017 10:25:24 -0700 Subject: [PATCH 452/512] [SPARK-20621][DEPLOY] 
Delete deprecated config parameter in 'spark-env.sh' ## What changes were proposed in this pull request? Currently, `spark.executor.instances` is deprecated in `spark-env.sh`, because we suggest configuring it in `spark-defaults.conf` or another config file. Moreover, this parameter has no effect even if you set it in `spark-env.sh`, so this patch removes it. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xianyang Liu Closes #17881 from ConeyLiu/deprecatedParam. --- conf/spark-env.sh.template | 1 - .../org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 94bd2c477a35..b7c985ace69c 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -34,7 +34,6 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 93578855122c..0fc994d629cc 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -280,10 +280,7 @@ object YarnSparkHadoopUtil { initialNumExecutors } else { - val targetNumExecutors = - sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(numExecutors) - // System property can override environment variable. - conf.get(EXECUTOR_INSTANCES).getOrElse(targetNumExecutors) + conf.get(EXECUTOR_INSTANCES).getOrElse(numExecutors) } } } From 829cd7b8b70e65a91aa66e6d626bd45f18e0ad97 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 8 May 2017 14:27:56 -0700 Subject: [PATCH 453/512] [SPARK-20605][CORE][YARN][MESOS] Deprecate not used AM and executor port configuration ## What changes were proposed in this pull request? After SPARK-10997, the client-mode Netty RpcEnv no longer needs to start a server, so these port configurations are not used any more; this patch removes the two configurations "spark.executor.port" and "spark.yarn.am.port". ## How was this patch tested? Existing UTs. Author: jerryshao Closes #17866 from jerryshao/SPARK-20605.
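As a concrete illustration of the `SparkConf` side of this change, the sketch below shows the deprecated-key registry pattern the patch extends, using the two entries added in the diff that follows. The class and object names here are hypothetical stand-ins; Spark's real `DeprecatedConfig` is private to `SparkConf`.

```scala
// Deprecated keys map to a version and message, so setting one produces a loud warning
// instead of silently configuring a port that nothing will ever bind.
final case class DeprecatedConfig(key: String, version: String, deprecationMessage: String)

object DeprecatedConfigs {
  private val configs: Map[String, DeprecatedConfig] = Seq(
    DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"),
    DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more")
  ).map(c => c.key -> c).toMap

  def warnIfDeprecated(key: String): Unit =
    configs.get(key).foreach { c =>
      println(s"Warning: ${c.key} is deprecated as of ${c.version}: ${c.deprecationMessage}.")
    }
}

object DeprecatedConfigsDemo extends App {
  DeprecatedConfigs.warnIfDeprecated("spark.executor.port") // warns: the port is never bound
  DeprecatedConfigs.warnIfDeprecated("spark.driver.port")   // silent: still a live setting
}
```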
--- .../scala/org/apache/spark/SparkConf.scala | 4 ++- .../scala/org/apache/spark/SparkEnv.scala | 14 +++----- .../CoarseGrainedExecutorBackend.scala | 5 ++- docs/running-on-mesos.md | 2 +- docs/running-on-yarn.md | 7 ---- .../spark/executor/MesosExecutorBackend.scala | 3 +- .../cluster/mesos/MesosSchedulerUtils.scala | 2 +- .../mesos/MesosSchedulerUtilsSuite.scala | 34 +++++-------------- .../spark/deploy/yarn/ApplicationMaster.scala | 3 +- .../org/apache/spark/deploy/yarn/config.scala | 5 --- 10 files changed, 22 insertions(+), 57 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 2a2ce0504dbb..956724b14bba 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -579,7 +579,9 @@ private[spark] object SparkConf extends Logging { "are no longer accepted. To specify the equivalent now, one may use '64k'."), DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", - "Please use the new blacklisting options, spark.blacklist.*") + "Please use the new blacklisting options, spark.blacklist.*"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index f4a59f069a5f..3196c1ece15e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -177,7 +177,7 @@ object SparkEnv extends Logging { SparkContext.DRIVER_IDENTIFIER, bindAddress, advertiseAddress, - port, + Option(port), isLocal, numCores, ioEncryptionKey, @@ -194,7 +194,6 @@ object SparkEnv extends Logging { conf: SparkConf, executorId: String, hostname: String, - port: Int, numCores: Int, ioEncryptionKey: Option[Array[Byte]], isLocal: Boolean): SparkEnv = { @@ -203,7 +202,7 @@ object SparkEnv extends Logging { executorId, hostname, hostname, - port, + None, isLocal, numCores, ioEncryptionKey @@ -220,7 +219,7 @@ object SparkEnv extends Logging { executorId: String, bindAddress: String, advertiseAddress: String, - port: Int, + port: Option[Int], isLocal: Boolean, numUsableCores: Int, ioEncryptionKey: Option[Array[Byte]], @@ -243,17 +242,12 @@ object SparkEnv extends Logging { } val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, + val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf, securityManager, clientMode = !isDriver) // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. - // In the non-driver case, the RPC env's address may be null since it may not be listening - // for incoming connections. 
if (isDriver) { conf.set("spark.driver.port", rpcEnv.address.port.toString) - } else if (rpcEnv.address != null) { - conf.set("spark.executor.port", rpcEnv.address.port.toString) - logInfo(s"Setting spark.executor.port to: ${rpcEnv.address.port.toString}") } // Create an instance of the class with the given name, possibly initializing it with our conf diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b2b26ee107c0..a2f1aa22b006 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -191,11 +191,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf - val port = executorConf.getInt("spark.executor.port", 0) val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, - port, + -1, executorConf, new SecurityManager(executorConf), clientMode = true) @@ -221,7 +220,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false) + driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 314a806edf39..c1344ad99a7d 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -209,7 +209,7 @@ provide such guarantees on the offer stream. In this mode spark executors will honor port allocation if such is provided from the user. Specifically if the user defines -`spark.executor.port` or `spark.blockManager.port` in Spark configuration, +`spark.blockManager.port` in Spark configuration, the mesos scheduler will check the available offers for a valid port range containing the port numbers. If no such range is available it will not launch any task. If no restriction is imposed on port numbers by the diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e9ddaa76a797..2d56123028f2 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -239,13 +239,6 @@ To use a custom metrics.properties for the application master and executors, upd Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. - - spark.yarn.am.port - (random) - - Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the YARN Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. 
- - spark.yarn.queue default diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index a086ec7ea2da..61bfa27a84fd 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -74,9 +74,8 @@ private[spark] class MesosExecutorBackend val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++ Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue)) val conf = new SparkConf(loadDefaults = true).setAll(properties) - val port = conf.getInt("spark.executor.port", 0) val env = SparkEnv.createExecutorEnv( - conf, executorId, slaveInfo.getHostname, port, cpusPerTask, None, isLocal = false) + conf, executorId, slaveInfo.getHostname, cpusPerTask, None, isLocal = false) executor = new Executor( executorId, diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 9d81025a3016..062ed1f93fa5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -438,7 +438,7 @@ trait MesosSchedulerUtils extends Logging { } } - val managedPortNames = List("spark.executor.port", BLOCK_MANAGER_PORT.key) + val managedPortNames = List(BLOCK_MANAGER_PORT.key) /** * The values of the non-zero ports to be used by the executor process. 
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index ec47ab153177..5d4bf6d082c4 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -179,40 +179,25 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("Port reservation is done correctly with user specified ports only") { val conf = new SparkConf() - conf.set("spark.executor.port", "3000" ) conf.set(BLOCK_MANAGER_PORT, 4000) val portResource = createTestPortResource((3000, 5000), Some("my_role")) val (resourcesLeft, resourcesToBeUsed) = utils - .partitionPortResources(List(3000, 4000), List(portResource)) - resourcesToBeUsed.length shouldBe 2 + .partitionPortResources(List(4000), List(portResource)) + resourcesToBeUsed.length shouldBe 1 val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1}.toArray - portsToUse.length shouldBe 2 - arePortsEqual(portsToUse, Array(3000L, 4000L)) shouldBe true + portsToUse.length shouldBe 1 + arePortsEqual(portsToUse, Array(4000L)) shouldBe true val portRangesToBeUsed = rangesResourcesToTuple(resourcesToBeUsed) - val expectedUSed = Array((3000L, 3000L), (4000L, 4000L)) + val expectedUSed = Array((4000L, 4000L)) arePortsEqual(portRangesToBeUsed.toArray, expectedUSed) shouldBe true } - test("Port reservation is done correctly with some user specified ports (spark.executor.port)") { - val conf = new SparkConf() - conf.set("spark.executor.port", "3100" ) - val portResource = createTestPortResource((3000, 5000), Some("my_role")) - - val (resourcesLeft, resourcesToBeUsed) = utils - .partitionPortResources(List(3100), List(portResource)) - - val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} - - portsToUse.length shouldBe 1 - portsToUse.contains(3100) shouldBe true - } - test("Port reservation is done correctly with all random ports") { val conf = new SparkConf() val portResource = createTestPortResource((3000L, 5000L), Some("my_role")) @@ -226,21 +211,20 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("Port reservation is done correctly with user specified ports only - multiple ranges") { val conf = new SparkConf() - conf.set("spark.executor.port", "2100" ) conf.set("spark.blockManager.port", "4000") val portResourceList = List(createTestPortResource((3000, 5000), Some("my_role")), createTestPortResource((2000, 2500), Some("other_role"))) val (resourcesLeft, resourcesToBeUsed) = utils - .partitionPortResources(List(2100, 4000), portResourceList) + .partitionPortResources(List(4000), portResourceList) val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} - portsToUse.length shouldBe 2 + portsToUse.length shouldBe 1 val portsRangesLeft = rangesResourcesToTuple(resourcesLeft) val portRangesToBeUsed = rangesResourcesToTuple(resourcesToBeUsed) - val expectedUsed = Array((2100L, 2100L), (4000L, 4000L)) + val expectedUsed = Array((4000L, 4000L)) - arePortsEqual(portsToUse.toArray, Array(2100L, 4000L)) shouldBe true + arePortsEqual(portsToUse.toArray, Array(4000L)) shouldBe true arePortsEqual(portRangesToBeUsed.toArray, expectedUsed) shouldBe true } diff --git 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 864c834d110f..6da2c0b5f330 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -429,8 +429,7 @@ private[spark] class ApplicationMaster( } private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { - val port = sparkConf.get(AM_PORT) - rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr, + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, -1, sparkConf, securityMgr, clientMode = true) val driverRef = waitForSparkDriver() addAmIpFilter() diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index d8c96c35ca71..d4108caab28c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -40,11 +40,6 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createOptional - private[spark] val AM_PORT = - ConfigBuilder("spark.yarn.am.port") - .intConf - .createWithDefault(0) - private[spark] val EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS = ConfigBuilder("spark.yarn.executor.failuresValidityInterval") .doc("Interval after which Executor failures will be considered independent and not " + From 2abfee18b6511482b916c36f00bf3abf68a59e19 Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 8 May 2017 14:48:11 -0700 Subject: [PATCH 454/512] [SPARK-20661][SPARKR][TEST] SparkR tableNames() test fails ## What changes were proposed in this pull request? Cleaning existing temp tables before running tableNames tests ## How was this patch tested? SparkR Unit tests Author: Hossein Closes #17903 from falaki/SPARK-20661. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index f517ce671313..ab6888ea34fd 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -677,6 +677,8 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { + # Making sure there are no registered temp tables from previous tests + suppressWarnings(sapply(tableNames(), function(tname) { dropTempTable(tname) })) df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") expect_equal(length(tableNames()), 1) From b952b44af4d243f1e3ad88bccf4af7d04df3fc81 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 May 2017 22:49:40 -0700 Subject: [PATCH 455/512] [SPARK-20661][SPARKR][TEST][FOLLOWUP] SparkR tableNames() test fails ## What changes were proposed in this pull request? Change it to check for relative count like in this test https://github.com/apache/spark/blame/master/R/pkg/inst/tests/testthat/test_sparkSQL.R#L3355 for catalog APIs ## How was this patch tested? unit tests, this needs to combine with another commit with SQL change to check Author: Felix Cheung Closes #17905 from felixcheung/rtabletests. 
--- R/pkg/inst/tests/testthat/test_sparkSQL.R | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ab6888ea34fd..19aa61e9a56c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -677,26 +677,27 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { - # Making sure there are no registered temp tables from previous tests - suppressWarnings(sapply(tableNames(), function(tname) { dropTempTable(tname) })) + count <- count(listTables()) + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") - expect_equal(length(tableNames()), 1) - expect_equal(length(tableNames("default")), 1) + expect_equal(length(tableNames()), count + 1) + expect_equal(length(tableNames("default")), count + 1) + tables <- listTables() - expect_equal(count(tables), 1) + expect_equal(count(tables), count + 1) expect_equal(count(tables()), count(tables)) expect_true("tableName" %in% colnames(tables())) expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) tables <- listTables() - expect_equal(count(tables), 2) + expect_equal(count(tables), count + 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) tables <- listTables() - expect_equal(count(tables), 0) + expect_equal(count(tables), count + 0) }) test_that( From 8079424763c2043264f30a6898ce964379bd9b56 Mon Sep 17 00:00:00 2001 From: Peng Date: Tue, 9 May 2017 10:05:49 +0200 Subject: [PATCH 456/512] [SPARK-11968][MLLIB] Optimize MLLIB ALS recommendForAll The recommendForAll of MLlib ALS is very slow, and GC is a key problem of the current method. Each task uses the following code to keep the temporary result: val output = new Array[(Int, (Int, Double))](m*n) with m = n = 4096 (the default value, with no way to change it), so output is about 4k * 4k * (4 + 4 + 8) = 256M. This large allocation causes serious GC pressure and frequently OOMs. Actually, we don't need to save all the temporary results. Suppose we recommend topK (topK is about 10 or 20) products for each user; then we only need 4k * topK * (4 + 4 + 8) memory to save the temporary result. The Test Environment: 3 workers, each with 10 cores, 30G memory, and 1 executor. The Data: 480,000 users and 17,000 items. BlockSize: 1024 2048 4096 8192 Old method: 245s 332s 488s OOM This solution: 121s 118s 117s 120s ## How was this patch tested? The existing UT. Author: Peng Author: Peng Meng Closes #17742 from mpjlu/OptimizeAls. --- .../MatrixFactorizationModel.scala | 81 ++++++++++++------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 23045fa2b686..d45866c016d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.BoundedPriorityQueue /** * Model representing the result of matrix factorization.
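Aside, to make the approach described in the commit message concrete before the diff continues: the patch replaces the m x n temporary array with a per-source bounded priority queue holding only the topK candidates. The sketch below is not part of the patch and uses made-up names; a size-capped `scala.collection.mutable.PriorityQueue` stands in for Spark's `BoundedPriorityQueue`, and plain `Seq`s stand in for the blocked RDD partitions.

```scala
import scala.collection.mutable

object TopKSketch {
  // Keep only the k best-scoring destination ids per source while streaming over a block of
  // destination factors, instead of materializing the full m x n block of scores up front.
  def topKPerSource(
      srcBlock: Seq[(Int, Array[Double])],
      dstBlock: Seq[(Int, Array[Double])],
      k: Int): Map[Int, Seq[(Int, Double)]] = {
    // Reversed ordering turns the max-heap into a min-heap: the head is the weakest survivor.
    val minByScore = Ordering.by[(Int, Double), Double](_._2).reverse
    srcBlock.map { case (srcId, srcFactor) =>
      val heap = mutable.PriorityQueue.empty[(Int, Double)](minByScore)
      dstBlock.foreach { case (dstId, dstFactor) =>
        var score = 0.0
        var i = 0
        while (i < srcFactor.length) {      // hand-written dot product, as in the patch
          score += srcFactor(i) * dstFactor(i)
          i += 1
        }
        heap.enqueue(dstId -> score)
        if (heap.size > k) heap.dequeue()   // evict the weakest to stay bounded at k entries
      }
      srcId -> heap.toSeq.sortBy { case (_, score) => -score } // best first
    }.toMap
  }

  def main(args: Array[String]): Unit = {
    // Rank-2 factors, top-1 recommendation per source: source 0 -> item 10, source 1 -> item 11.
    val src = Seq(0 -> Array(1.0, 0.0), 1 -> Array(0.0, 1.0))
    val dst = Seq(10 -> Array(2.0, 0.1), 11 -> Array(0.1, 3.0))
    println(topKPerSource(src, dst, k = 1))
  }
}
```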
@@ -274,46 +275,64 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { srcFeatures: RDD[(Int, Array[Double])], dstFeatures: RDD[(Int, Array[Double])], num: Int): RDD[(Int, Array[(Int, Double)])] = { - val srcBlocks = blockify(rank, srcFeatures) - val dstBlocks = blockify(rank, dstFeatures) - val ratings = srcBlocks.cartesian(dstBlocks).flatMap { - case ((srcIds, srcFactors), (dstIds, dstFactors)) => - val m = srcIds.length - val n = dstIds.length - val ratings = srcFactors.transpose.multiply(dstFactors) - val output = new Array[(Int, (Int, Double))](m * n) - var k = 0 - ratings.foreachActive { (i, j, r) => - output(k) = (srcIds(i), (dstIds(j), r)) - k += 1 + val srcBlocks = blockify(srcFeatures) + val dstBlocks = blockify(dstFeatures) + /** + * The previous approach used for computing top-k recommendations aimed to group + * individual factor vectors into blocks, so that Level 3 BLAS operations (gemm) could + * be used for efficiency. However, this causes excessive GC pressure due to the large + * arrays required for intermediate result storage, as well as a high sensitivity to the + * block size used. + * The following approach still groups factors into blocks, but instead computes the + * top-k elements per block, using a simple dot product (instead of gemm) and an efficient + * [[BoundedPriorityQueue]]. This avoids any large intermediate data structures and results + * in significantly reduced GC pressure as well as shuffle data, which far outweighs + * any cost incurred from not using Level 3 BLAS operations. + */ + val ratings = srcBlocks.cartesian(dstBlocks).flatMap { case (srcIter, dstIter) => + val m = srcIter.size + val n = math.min(dstIter.size, num) + val output = new Array[(Int, (Int, Double))](m * n) + var j = 0 + val pq = new BoundedPriorityQueue[(Int, Double)](n)(Ordering.by(_._2)) + srcIter.foreach { case (srcId, srcFactor) => + dstIter.foreach { case (dstId, dstFactor) => + /* + * The below code is equivalent to + * `val score = blas.ddot(rank, srcFactor, 1, dstFactor, 1)` + * This handwritten version is as or more efficient as BLAS calls in this case. + */ + var score: Double = 0 + var k = 0 + while (k < rank) { + score += srcFactor(k) * dstFactor(k) + k += 1 + } + pq += dstId -> score + } + val pqIter = pq.iterator + var i = 0 + while (i < n) { + output(j + i) = (srcId, pqIter.next()) + i += 1 } - output.toSeq + j += n + pq.clear() + } + output.toSeq } ratings.topByKey(num)(Ordering.by(_._2)) } /** - * Blockifies features to use Level-3 BLAS. + * Blockifies features to improve the efficiency of cartesian product + * TODO: SPARK-20443 - expose blockSize as a param? 
*/ private def blockify( - rank: Int, - features: RDD[(Int, Array[Double])]): RDD[(Array[Int], DenseMatrix)] = { - val blockSize = 4096 // TODO: tune the block size - val blockStorage = rank * blockSize + features: RDD[(Int, Array[Double])], + blockSize: Int = 4096): RDD[Seq[(Int, Array[Double])]] = { features.mapPartitions { iter => - iter.grouped(blockSize).map { grouped => - val ids = mutable.ArrayBuilder.make[Int] - ids.sizeHint(blockSize) - val factors = mutable.ArrayBuilder.make[Double] - factors.sizeHint(blockStorage) - var i = 0 - grouped.foreach { case (id, factor) => - ids += id - factors ++= factor - i += 1 - } - (ids.result(), new DenseMatrix(rank, i, factors.result())) - } + iter.grouped(blockSize) } } From 10b00abadf4a3473332eef996db7b66f491316f2 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Tue, 9 May 2017 10:13:15 +0200 Subject: [PATCH 457/512] [SPARK-20587][ML] Improve performance of ML ALS recommendForAll This PR is a `DataFrame` version of #17742 for [SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968), for improving the performance of `recommendAll` methods. ## How was this patch tested? Existing unit tests. Author: Nick Pentreath Closes #17845 from MLnick/ml-als-perf. --- .../apache/spark/ml/recommendation/ALS.scala | 71 +++++++++++++++++-- 1 file changed, 64 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 1562bf1beb7e..d626f0459967 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom @@ -356,6 +356,19 @@ class ALSModel private[ml] ( /** * Makes recommendations for all users (or items). + * + * Note: the previous approach used for computing top-k recommendations + * used a cross-join followed by predicting a score for each row of the joined dataset. + * However, this results in exploding the size of intermediate data. While Spark SQL makes it + * relatively efficient, the approach implemented here is significantly more efficient. + * + * This approach groups factors into blocks and computes the top-k elements per block, + * using a simple dot product (instead of gemm) and an efficient [[BoundedPriorityQueue]]. + * It then computes the global top-k by aggregating the per block top-k elements with + * a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data. + * This is the DataFrame equivalent to the approach used in + * [[org.apache.spark.mllib.recommendation.MatrixFactorizationModel]]. 
+ * * @param srcFactors src factors for which to generate recommendations * @param dstFactors dst factors used to make recommendations * @param srcOutputColumn name of the column for the source ID in the output DataFrame @@ -372,11 +385,43 @@ class ALSModel private[ml] ( num: Int): DataFrame = { import srcFactors.sparkSession.implicits._ - val ratings = srcFactors.crossJoin(dstFactors) - .select( - srcFactors("id"), - dstFactors("id"), - predict(srcFactors("features"), dstFactors("features"))) + val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])]) + val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])]) + val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked) + .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])] + .flatMap { case (srcIter, dstIter) => + val m = srcIter.size + val n = math.min(dstIter.size, num) + val output = new Array[(Int, Int, Float)](m * n) + var j = 0 + val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2)) + srcIter.foreach { case (srcId, srcFactor) => + dstIter.foreach { case (dstId, dstFactor) => + /* + * The below code is equivalent to + * `val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)` + * This handwritten version is as or more efficient as BLAS calls in this case. + */ + var score = 0.0f + var k = 0 + while (k < rank) { + score += srcFactor(k) * dstFactor(k) + k += 1 + } + pq += dstId -> score + } + val pqIter = pq.iterator + var i = 0 + while (i < n) { + val (dstId, score) = pqIter.next() + output(j + i) = (srcId, dstId, score) + i += 1 + } + j += n + pq.clear() + } + output.toSeq + } // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output. val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2)) val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn) @@ -387,8 +432,20 @@ class ALSModel private[ml] ( .add(dstOutputColumn, IntegerType) .add("rating", FloatType) ) - recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) + recs.select($"id".as(srcOutputColumn), $"recommendations".cast(arrayType)) } + + /** + * Blockifies factors to improve the efficiency of cross join + * TODO: SPARK-20443 - expose blockSize as a param? + */ + private def blockify( + factors: Dataset[(Int, Array[Float])], + blockSize: Int = 4096): Dataset[Seq[(Int, Array[Float])]] = { + import factors.sparkSession.implicits._ + factors.mapPartitions(_.grouped(blockSize)) + } + } @Since("1.6.0") From be53a78352ae7c70d8a07d0df24574b3e3129b4a Mon Sep 17 00:00:00 2001 From: Jon McLean Date: Tue, 9 May 2017 09:47:50 +0100 Subject: [PATCH 458/512] [SPARK-20615][ML][TEST] SparseVector.argmax throws IndexOutOfBoundsException ## What changes were proposed in this pull request? Added a check for for the number of defined values. Previously the argmax function assumed that at least one value was defined if the vector size was greater than zero. ## How was this patch tested? Tests were added to the existing VectorsSuite to cover this case. Author: Jon McLean Closes #17877 from jonmclean/vectorArgmaxIndexBug. 
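To make the failure mode concrete: a `SparseVector` can have a positive `size` while storing no active entries, in which case every value is an implicit 0.0, yet the old code still read `indices(0)`. Below is a REPL-pasteable sketch of the corrected control flow, not MLlib's actual class; the real implementation also lets an implicit zero beat all-negative active values, which is elided here.

```scala
def argmaxSketch(size: Int, indices: Array[Int], values: Array[Double]): Int = {
  if (size == 0) {
    -1                       // empty vector: there is no valid index at all
  } else if (values.isEmpty) {
    0                        // non-empty but all-zero vector: index 0 is a valid argmax
  } else {
    var maxIdx = indices(0)  // safe now: indices/values are known to be non-empty
    var maxValue = values(0)
    var i = 1
    while (i < values.length) {
      if (values(i) > maxValue) { maxValue = values(i); maxIdx = indices(i) }
      i += 1
    }
    maxIdx
  }
}

// The case covered by the new tests: a 100-dimensional sparse vector with no stored values.
// argmaxSketch(100, Array.empty[Int], Array.empty[Double])  // == 0 (used to throw)
// argmaxSketch(0, Array.empty[Int], Array.empty[Double])    // == -1
```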
--- .../main/scala/org/apache/spark/ml/linalg/Vectors.scala | 2 ++ .../scala/org/apache/spark/ml/linalg/VectorsSuite.scala | 7 +++++++ .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 ++ .../scala/org/apache/spark/mllib/linalg/VectorsSuite.scala | 7 +++++++ 4 files changed, 18 insertions(+) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 8e166ba0ff51..3fbc0958a0f1 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -657,6 +657,8 @@ class SparseVector @Since("2.0.0") ( override def argmax: Int = { if (size == 0) { -1 + } else if (numActives == 0) { + 0 } else { // Find the max active entry. var maxIdx = indices(0) diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index dfbdaf19d374..4cd91afd6d7f 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -125,6 +125,13 @@ class VectorsSuite extends SparkMLFunSuite { val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) assert(vec8.argmax === 0) + + // Check for case when sparse vector is non-empty but the values are empty + val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec9.argmax === 0) + + val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec10.argmax === 0) } test("vector equals") { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 723addc7150d..f063420bec14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -846,6 +846,8 @@ class SparseVector @Since("1.0.0") ( override def argmax: Int = { if (size == 0) { -1 + } else if (numActives == 0) { + 0 } else { // Find the max active entry. var maxIdx = indices(0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 71a3ceac1b94..6172cffee861 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -122,6 +122,13 @@ class VectorsSuite extends SparkFunSuite with Logging { val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) assert(vec8.argmax === 0) + + // Check for case when sparse vector is non-empty but the values are empty + val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec9.argmax === 0) + + val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec10.argmax === 0) } test("vector equals") { From b8733e0ad9f5a700f385e210450fd2c10137293e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 9 May 2017 17:30:37 +0800 Subject: [PATCH 459/512] [SPARK-20606][ML] ML 2.2 QA: Remove deprecated methods for ML ## What changes were proposed in this pull request? Remove ML methods we deprecated in 2.1. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #17867 from yanboliang/spark-20606. 
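The diff below is long but mechanical, so here is the shape of the change in miniature, with hypothetical classes rather than Spark's `Param` machinery: the `set*` methods are now declared directly on each concrete estimator and return `this.type`, instead of overriding setters that previously lived on the shared tree-params traits and were deprecated in 2.1.

```scala
trait HasMaxDepth {
  protected var maxDepth: Int = 5  // stands in for Spark's Param[Int] plus its default
}

class SimpleDecisionTreeClassifier extends HasMaxDepth {
  // Declared here (note: no `override`), so fluent chaining keeps the concrete type.
  def setMaxDepth(value: Int): this.type = { maxDepth = value; this }
}

// new SimpleDecisionTreeClassifier().setMaxDepth(10)  // still a SimpleDecisionTreeClassifier
```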
--- .../DecisionTreeClassifier.scala | 18 +-- .../ml/classification/GBTClassifier.scala | 24 ++-- .../RandomForestClassifier.scala | 24 ++-- .../ml/regression/DecisionTreeRegressor.scala | 18 +-- .../spark/ml/regression/GBTRegressor.scala | 24 ++-- .../ml/regression/RandomForestRegressor.scala | 24 ++-- .../org/apache/spark/ml/tree/treeParams.scala | 105 ------------------ .../org/apache/spark/ml/util/ReadWrite.scala | 16 --- project/MimaExcludes.scala | 68 ++++++++++++ python/pyspark/ml/util.py | 32 ------ 10 files changed, 134 insertions(+), 219 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec5..5fb105c6aff6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -54,27 +54,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -86,15 +86,15 @@ class DecisionTreeClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ade0960f87a0..263ed10f1985 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -70,27 +70,27 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -102,7 +102,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. 
@@ -111,7 +111,7 @@ class GBTClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { + def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this } @@ -120,21 +120,21 @@ class GBTClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - override def setMaxIter(value: Int): this.type = set(maxIter, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - override def setStepSize(value: Double): this.type = set(stepSize, value) + def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTClassifierParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ab4c23520928..441cfda89927 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -56,27 +56,27 @@ class RandomForestClassifier @Since("1.4.0") ( /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -88,31 +88,31 @@ class RandomForestClassifier @Since("1.4.0") ( * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 01c5cc1c7efa..c2b0358e8405 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -53,27 +53,27 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S // Override parameter setters from parent trait for Java API compatibility. /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -85,15 +85,15 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) /** @group setParam */ @Since("1.6.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 08d175cb9444..8d9b519efb14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -68,27 +68,27 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. @@ -100,7 +100,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** * The impurity setting is ignored for GBT models. 
@@ -109,7 +109,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = { + def setImpurity(value: String): this.type = { logWarning("GBTRegressor.setImpurity should NOT be used") this } @@ -118,21 +118,21 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from GBTParams: /** @group setParam */ @Since("1.4.0") - override def setMaxIter(value: Int): this.type = set(maxIter, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ @Since("1.4.0") - override def setStepSize(value: Double): this.type = set(stepSize, value) + def setStepSize(value: Double): this.type = set(stepSize, value) // Parameters from GBTRegressorParams: diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index a58da50fad97..7b9ddf6e9521 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -55,27 +55,27 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** @group setParam */ @Since("1.4.0") - override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + def setMaxDepth(value: Int): this.type = set(maxDepth, value) /** @group setParam */ @Since("1.4.0") - override def setMaxBins(value: Int): this.type = set(maxBins, value) + def setMaxBins(value: Int): this.type = set(maxBins, value) /** @group setParam */ @Since("1.4.0") - override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ @Since("1.4.0") - override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group expertSetParam */ @Since("1.4.0") - override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) /** @group expertSetParam */ @Since("1.4.0") - override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** * Specifies how often to checkpoint the cached node IDs. 
@@ -87,31 +87,31 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * @group setParam */ @Since("1.4.0") - override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ @Since("1.4.0") - override def setImpurity(value: String): this.type = set(impurity, value) + def setImpurity(value: String): this.type = set(impurity, value) // Parameters from TreeEnsembleParams: /** @group setParam */ @Since("1.4.0") - override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) /** @group setParam */ @Since("1.4.0") - override def setSeed(value: Long): this.type = set(seed, value) + def setSeed(value: Long): this.type = set(seed, value) // Parameters from RandomForestParams: /** @group setParam */ @Since("1.4.0") - override def setNumTrees(value: Int): this.type = set(numTrees, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group setParam */ @Since("1.4.0") - override def setFeatureSubsetStrategy(value: String): this.type = + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index cd1950bd76c0..5526d4d75bd7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -109,80 +109,24 @@ private[ml] trait DecisionTreeParams extends PredictorParams setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMaxDepth(value: Int): this.type = set(maxDepth, value) - /** @group getParam */ final def getMaxDepth: Int = $(maxDepth) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMaxBins(value: Int): this.type = set(maxBins, value) - /** @group getParam */ final def getMaxBins: Int = $(maxBins) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) - /** @group getParam */ final def getMinInstancesPerNode: Int = $(minInstancesPerNode) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) - /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. 
- * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setSeed(value: Long): this.type = set(seed, value) - - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group expertSetParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) - /** @group expertGetParam */ final def getMaxMemoryInMB: Int = $(maxMemoryInMB) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group expertSetParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) - /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], @@ -225,13 +169,6 @@ private[ml] trait TreeClassifierParams extends Params { setDefault(impurity -> "gini") - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -276,13 +213,6 @@ private[ml] trait TreeRegressorParams extends Params { setDefault(impurity -> "variance") - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setImpurity(value: String): this.type = set(impurity, value) - /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) @@ -338,13 +268,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { setDefault(subsamplingRate -> 1.0) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) - /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) @@ -382,13 +305,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(numTrees -> 20) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setNumTrees(value: Int): this.type = set(numTrees, value) - /** @group getParam */ final def getNumTrees: Int = $(numTrees) @@ -430,13 +346,6 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { setDefault(featureSubsetStrategy -> "auto") - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. 
- * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - /** @group getParam */ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } @@ -471,13 +380,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") // validationTol -> 1e-5 - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setMaxIter(value: Int): this.type = set(maxIter, value) - /** * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking * the contribution of each estimator. @@ -491,13 +393,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group getParam */ final def getStepSize: Double = $(stepSize) - /** - * @deprecated This method is deprecated and will be removed in 2.2.0. - * @group setParam - */ - @deprecated("This method is deprecated and will be removed in 2.2.0.", "2.1.0") - def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1) /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a8b80031faf8..f7e570fd5cc9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -42,16 +42,6 @@ import org.apache.spark.util.Utils private[util] sealed trait BaseReadWrite { private var optionSparkSession: Option[SparkSession] = None - /** - * Sets the Spark SQLContext to use for saving/loading. - */ - @Since("1.6.0") - @deprecated("Use session instead, This method will be removed in 2.2.0.", "2.0.0") - def context(sqlContext: SQLContext): this.type = { - optionSparkSession = Option(sqlContext.sparkSession) - this - } - /** * Sets the Spark Session to use for saving/loading. 
*/ @@ -130,9 +120,6 @@ abstract class MLWriter extends BaseReadWrite with Logging { // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) - - // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** @@ -188,9 +175,6 @@ abstract class MLReader[T] extends BaseReadWrite { // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) - - // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d50882cb1917..d8b37aebb5d1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -1005,6 +1005,74 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy") + ) ++ Seq( + // [SPARK-20606] ML 2.2 QA: Remove deprecated methods for ML + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxIter"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCheckpointInterval"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setStepSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInstancesPerNode"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxIter"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setStepSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSeed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInfoGain"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSubsamplingRate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCacheNodeIds"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCheckpointInterval"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxDepth"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxMemoryInMB"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxBins"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInstancesPerNode"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.context"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.context") ) } diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 02016f172aeb..688109ab11fd 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -76,13 +76,6 @@ def overwrite(self): """Overwrites if the output path already exists.""" raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) - def context(self, sqlContext): - """ - Sets the SQL context to use for saving. - .. 
note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. - """ - raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) - def session(self, sparkSession): """Sets the Spark Session to use for saving.""" raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) @@ -110,15 +103,6 @@ def overwrite(self): self._jwrite.overwrite() return self - def context(self, sqlContext): - """ - Sets the SQL context to use for saving. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. - """ - warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") - self._jwrite.context(sqlContext._ssql_ctx) - return self - def session(self, sparkSession): """Sets the Spark Session to use for saving.""" self._jwrite.session(sparkSession._jsparkSession) @@ -165,13 +149,6 @@ def load(self, path): """Load the ML instance from the input path.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) - def context(self, sqlContext): - """ - Sets the SQL context to use for loading. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. - """ - raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) - def session(self, sparkSession): """Sets the Spark Session to use for loading.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) @@ -197,15 +174,6 @@ def load(self, path): % self._clazz) return self._clazz._from_java(java_obj) - def context(self, sqlContext): - """ - Sets the SQL context to use for loading. - .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. - """ - warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") - self._jread.context(sqlContext._ssql_ctx) - return self - def session(self, sparkSession): """Sets the Spark Session to use for loading.""" self._jread.session(sparkSession._jsparkSession) From 0d00c768a860fc03402c8f0c9081b8147c29133e Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Tue, 9 May 2017 20:10:50 +0800 Subject: [PATCH 460/512] [SPARK-20667][SQL][TESTS] Cleanup the cataloged metadata after completing the package of sql/core and sql/hive ## What changes were proposed in this pull request? So far, we do not drop all the cataloged objects after each package. Sometimes, we might hit strange test case errors because the previous test suite did not drop the cataloged/temporary objects (tables/functions/database). At least, we can first clean up the environment when completing the package of `sql/core` and `sql/hive`. ## How was this patch tested? N/A Author: Xiao Li Closes #17908 from gatorsmile/reset. 
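For reference only (not part of the patch), a minimal sketch of the cleanup pattern this change applies, written as a ScalaTest mixin. The trait name `CatalogCleanup` and the `spark` accessor are illustrative; note that `sessionState` is `private[sql]`, so a call like this only compiles for suites that, like the ones touched here, live under `org.apache.spark.sql`.

```scala
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, Suite}

// Sketch only: clear cataloged state once a suite finishes so temporary
// tables, views, functions and databases cannot leak into the next suite.
trait CatalogCleanup extends BeforeAndAfterAll { self: Suite =>

  // Assumed to be provided by the concrete suite (e.g. a shared test session).
  protected def spark: SparkSession

  override protected def afterAll(): Unit = {
    try {
      if (spark != null) {
        // SessionCatalog.reset() drops non-default databases, temp views and
        // temp functions, clears the relation cache and restores built-ins.
        spark.sessionState.catalog.reset()
      }
    } finally {
      super.afterAll()
    }
  }
}
```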
--- .../apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 3 ++- .../scala/org/apache/spark/sql/test/SharedSQLContext.scala | 1 + .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 7 +------ 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 6c6d600190b6..18e514681e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1251,9 +1251,10 @@ class SessionCatalog( dropTempFunction(func.funcName, ignoreIfNotExists = false) } } - tempTables.clear() + clearTempTables() globalTempViewManager.clear() functionRegistry.clear() + tableRelationCache.invalidateAll() // restore built-in functions FunctionRegistry.builtin.listFunction().foreach { f => val expressionInfo = FunctionRegistry.builtin.lookupFunction(f) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 81c69a338abc..7cea4c02155e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -74,6 +74,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventua protected override def afterAll(): Unit = { super.afterAll() if (_spark != null) { + _spark.sessionState.catalog.reset() _spark.stop() _spark = null } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index d9bb1f8c7edc..ee9ac21a738d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -488,14 +488,9 @@ private[hive] class TestHiveSparkSession( sharedState.cacheManager.clearCache() loadedTables.clear() - sessionState.catalog.clearTempTables() - sessionState.catalog.tableRelationCache.invalidateAll() - + sessionState.catalog.reset() metadataHive.reset() - FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). - foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } - // HDFS root scratch dir requires the write all (733) permission. For each connecting user, // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to From 714811d0b5bcb5d47c39782ff74f898d276ecc59 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 9 May 2017 20:22:51 +0800 Subject: [PATCH 461/512] [SPARK-20311][SQL] Support aliases for table value functions ## What changes were proposed in this pull request? This pr added parsing rules to support aliases in table value functions. ## How was this patch tested? Added tests in `PlanParserSuite`. Author: Takeshi Yamamuro Closes #17666 from maropu/SPARK-20311. 
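For illustration only (not part of the patch), here is roughly what the new `functionTable`/`tableAlias` grammar rules allow at the SQL level, driven from Scala; the object name and master setting are arbitrary.

```scala
import org.apache.spark.sql.SparkSession

object TvfAliasExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("tvf-alias").master("local[*]").getOrCreate()

    // Plain table-valued function call (already supported before this change).
    spark.sql("SELECT id FROM range(3)").show()

    // New with this patch: alias the function, optionally renaming its columns.
    spark.sql("SELECT t.id FROM range(3) t").show()
    spark.sql("SELECT t.a FROM range(3) t(a)").show()

    // A mismatched column list fails analysis with a message like
    // "expected 1 columns but found 2 columns":
    //   SELECT * FROM range(3) t(a, b)

    spark.stop()
  }
}
```

The `AnalysisSuite` and `PlanParserSuite` additions below exercise the same shapes, including the mismatched-alias error.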
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 20 ++++++++++++----- .../ResolveTableValuedFunctions.scala | 22 ++++++++++++++++--- .../sql/catalyst/analysis/unresolved.scala | 10 +++++++-- .../sql/catalyst/parser/AstBuilder.scala | 17 ++++++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 14 +++++++++++- .../sql/catalyst/parser/PlanParserSuite.scala | 13 ++++++++++- 6 files changed, 79 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 14c511f67060..41daf58a98fd 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -472,15 +472,23 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? strictIdentifier)? #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery - | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction + : tableIdentifier sample? (AS? strictIdentifier)? #tableName + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction ; inlineTable - : VALUES expression (',' expression)* (AS? identifier identifierList?)? + : VALUES expression (',' expression)* tableAlias + ; + +functionTable + : identifier '(' (expression (',' expression)*)? ')' tableAlias + ; + +tableAlias + : (AS? identifier identifierList?)? ; rowFormat diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index de6de24350f2..dad1340571cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { + val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { @@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") } + + // If alias names assigned, add `Project` with the aliases + if (u.outputNames.nonEmpty) { + val outputAttrs = resolvedFunc.output + // 
Checks if the number of the aliases is equal to expected one + if (u.outputNames.size != outputAttrs.size) { + u.failAnalysis(s"expected ${outputAttrs.size} columns but " + + s"found ${u.outputNames.size} columns") + } + val aliases = outputAttrs.zip(u.outputNames).map { + case (attr, name) => Alias(attr, name)() + } + Project(aliases, resolvedFunc) + } else { + resolvedFunc + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 262b894e2a0a..51bef6e20b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -66,10 +66,16 @@ case class UnresolvedInlineTable( /** * A table-valued function, e.g. * {{{ - * select * from range(10); + * select id from range(10); + * + * // Assign alias names + * select t.a from range(10) t(a); * }}} */ -case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) +case class UnresolvedTableValuedFunction( + functionName: String, + functionArgs: Seq[Expression], + outputNames: Seq[String]) extends LeafNode { override def output: Seq[Attribute] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2a9b4a9a9f5..e03fe2ccb8d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { - UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + + val tvf = UnresolvedTableValuedFunction( + func.identifier.getText, func.expression.asScala.map(expression), aliases) + tvf.optionalMap(func.tableAlias.identifier)(aliasPlan) } /** @@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } - val aliases = if (ctx.identifierList != null) { - visitIdentifierList(ctx.identifierList) + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) } else { Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } val table = UnresolvedInlineTable(aliases, rows) - table.optionalMap(ctx.identifier)(aliasPlan) + table.optionalMap(ctx.tableAlias.identifier)(aliasPlan) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 893bb1b74cea..31047f688600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } + + test("SPARK-20311 range(N) as alias") { + def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = { + SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames)) + .select(star()) + } + assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil)) + assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil)) + assertAnalysisError( + rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil), + Seq("expected 1 columns but found 2 columns")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 411777d6e85a..4c2476296c04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest { test("table valued function") { assertEqual( "select * from range(2)", - UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star())) + } + + test("SPARK-20311 range(N) as alias") { + assertEqual( + "select * from range(10) AS t", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty)) + .select(star())) + assertEqual( + "select * from range(7) AS t(a)", + SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil)) + .select(star())) } test("inline table") { From 181261a81d592b93181135a8267570e0c9ab2243 Mon Sep 17 00:00:00 2001 From: Sanket Date: Tue, 9 May 2017 09:30:09 -0500 Subject: [PATCH 462/512] [SPARK-20355] Add per application spark version on the history server headerpage ## What changes were proposed in this pull request? Spark Version for a specific application is not displayed on the history page now. It should be nice to switch the spark version on the UI when we click on the specific application. Currently there seems to be way as SparkListenerLogStart records the application version. So, it should be trivial to listen to this event and provision this change on the UI. For Example screen shot 2017-04-06 at 3 23 41 pm screen shot 2017-04-17 at 9 59 33 am {"Event":"SparkListenerLogStart","Spark Version":"2.0.0"} (Please fill in changes proposed in this fix) Modified the SparkUI for History server to listen to SparkLogListenerStart event and extract the version and print it. ## How was this patch tested? Manual testing of UI page. Attaching the UI screenshot changes here (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Sanket Closes #17658 from redsanket/SPARK-20355. 
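For illustration only, a minimal sketch of the listener pattern this patch relies on: `SparkListenerLogStart` is the first record in an event log (e.g. `{"Event":"SparkListenerLogStart","Spark Version":"2.0.0"}`) and, now that it is public and no longer filtered out by the listener bus, the version can be picked up in `onOtherEvent`. The class name `AppVersionListener` is illustrative.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerLogStart}

// Sketch only: capture the Spark version that wrote an event log while
// replaying it, mirroring what ApplicationEventListener does in this patch.
class AppVersionListener extends SparkListener {
  @volatile var appSparkVersion: Option[String] = None

  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case SparkListenerLogStart(sparkVersion) => appSparkVersion = Some(sparkVersion)
    case _ => // not interested in other events
  }
}
```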
--- .../history/ApplicationHistoryProvider.scala | 3 ++- .../deploy/history/FsHistoryProvider.scala | 17 ++++++++++++----- .../scheduler/ApplicationEventListener.scala | 7 +++++++ .../spark/scheduler/EventLoggingListener.scala | 13 ++++++++++--- .../apache/spark/scheduler/SparkListener.scala | 4 ++-- .../spark/scheduler/SparkListenerBus.scala | 1 - .../status/api/v1/ApplicationListResource.scala | 3 ++- .../org/apache/spark/status/api/v1/api.scala | 3 ++- .../scala/org/apache/spark/ui/SparkUI.scala | 6 +++++- .../scala/org/apache/spark/ui/UIUtils.scala | 2 +- .../application_list_json_expectation.json | 10 ++++++++++ .../completed_app_list_json_expectation.json | 11 +++++++++++ .../limit_app_list_json_expectation.json | 3 +++ .../maxDate2_app_list_json_expectation.json | 1 + .../maxDate_app_list_json_expectation.json | 2 ++ .../maxEndDate_app_list_json_expectation.json | 7 +++++++ ...nd_maxEndDate_app_list_json_expectation.json | 4 ++++ .../minDate_app_list_json_expectation.json | 8 ++++++++ ...nd_maxEndDate_app_list_json_expectation.json | 4 ++++ .../minEndDate_app_list_json_expectation.json | 6 +++++- .../one_app_json_expectation.json | 1 + .../one_app_multi_attempt_json_expectation.json | 2 ++ .../deploy/history/ApplicationCacheSuite.scala | 2 +- .../deploy/history/FsHistoryProviderSuite.scala | 4 ++-- project/MimaExcludes.scala | 3 +++ 25 files changed, 107 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 6d8758a3d3b1..5cb48ca3e60b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -30,7 +30,8 @@ private[spark] case class ApplicationAttemptInfo( endTime: Long, lastUpdated: Long, sparkUser: String, - completed: Boolean = false) + completed: Boolean = false, + appSparkVersion: String) private[spark] case class ApplicationHistoryInfo( id: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f4235df24512..d05ca142b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -248,7 +248,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, - HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) + HistoryServer.getAttemptURI(appId, attempt.attemptId), + attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } @@ -257,6 +258,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) if (appListener.appId.isDefined) { + ui.appSparkVersion = appListener.appSparkVersion.getOrElse("") ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) // make sure to set admin acls before view acls so they are properly picked up val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") @@ -443,7 +445,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = try { val 
eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || - eventString.startsWith(APPL_END_EVENT_PREFIX) + eventString.startsWith(APPL_END_EVENT_PREFIX) || + eventString.startsWith(LOG_START_EVENT_PREFIX) } val logPath = fileStatus.getPath() @@ -469,7 +472,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) lastUpdated, appListener.sparkUser.getOrElse(NOT_STARTED), appCompleted, - fileStatus.getLen() + fileStatus.getLen(), + appListener.appSparkVersion.getOrElse("") ) fileToAppInfo(logPath) = attemptInfo logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") @@ -735,6 +739,8 @@ private[history] object FsHistoryProvider { private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" private val APPL_END_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationEnd\"" + + private val LOG_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerLogStart\"" } /** @@ -762,9 +768,10 @@ private class FsApplicationAttemptInfo( lastUpdated: Long, sparkUser: String, completed: Boolean, - val fileSize: Long) + val fileSize: Long, + appSparkVersion: String) extends ApplicationAttemptInfo( - attemptId, startTime, endTime, lastUpdated, sparkUser, completed) { + attemptId, startTime, endTime, lastUpdated, sparkUser, completed, appSparkVersion) { /** extend the superclass string value with the extra attributes of this class */ override def toString: String = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 28c45d800ed0..6da8865cd10d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -34,6 +34,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var adminAcls: Option[String] = None var viewAclsGroups: Option[String] = None var adminAclsGroups: Option[String] = None + var appSparkVersion: Option[String] = None override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { appName = Some(applicationStart.appName) @@ -57,4 +58,10 @@ private[spark] class ApplicationEventListener extends SparkListener { adminAclsGroups = allProperties.get("spark.admin.acls.groups") } } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(sparkVersion) => + appSparkVersion = Some(sparkVersion) + case _ => + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index a7dbf87915b2..f48143633224 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -119,7 +119,7 @@ private[spark] class EventLoggingListener( val cstream = compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream) val bstream = new BufferedOutputStream(cstream, outputBufferSize) - EventLoggingListener.initEventLog(bstream) + EventLoggingListener.initEventLog(bstream, testing, loggedEvents) fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) writer = Some(new PrintWriter(bstream)) logInfo("Logging events to %s".format(logPath)) @@ -283,10 +283,17 @@ private[spark] object EventLoggingListener extends Logging { * * @param logStream Raw output stream to the event log file. 
*/ - def initEventLog(logStream: OutputStream): Unit = { + def initEventLog( + logStream: OutputStream, + testing: Boolean, + loggedEvents: ArrayBuffer[JValue]): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) - val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" + val eventJson = JsonProtocol.logStartToJson(metadata) + val metadataJson = compact(eventJson) + "\n" logStream.write(metadataJson.getBytes(StandardCharsets.UTF_8)) + if (testing && loggedEvents != null) { + loggedEvents += eventJson + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index bc2e53071668..59f89a82a1da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -160,9 +160,9 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent /** * An internal class that describes the metadata of an event log. - * This event is not meant to be posted to listeners downstream. */ -private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** * Interface for creating history listeners defined in other modules like SQL, which are used to diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 3ff363321e8c..3b0d3b1b150f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -71,7 +71,6 @@ private[spark] trait SparkListenerBus listener.onNodeUnblacklisted(nodeUnblacklisted) case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) - case logStart: SparkListenerLogStart => // ignore event log metadata case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index a0239266d875..f039744e7f67 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -90,7 +90,8 @@ private[spark] object ApplicationsListResource { }, lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, - completed = internalAttemptInfo.completed + completed = internalAttemptInfo.completed, + appSparkVersion = internalAttemptInfo.appSparkVersion ) } ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 56d8e51732ff..f6203271f3cd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -38,7 +38,8 @@ class ApplicationAttemptInfo private[spark]( val lastUpdated: Date, val duration: Long, val sparkUser: String, - val completed: Boolean = false) { + val completed: Boolean = false, + val appSparkVersion: String) { def getStartTimeEpoch: Long = startTime.getTime def getEndTimeEpoch: Long = endTime.getTime def getLastUpdatedEpoch: Long = lastUpdated.getTime diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala 
b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index bf4cf79e9faa..f271c56021e9 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -60,6 +60,8 @@ private[spark] class SparkUI private ( var appId: String = _ + var appSparkVersion = org.apache.spark.SPARK_VERSION + private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. */ @@ -118,7 +120,8 @@ private[spark] class SparkUI private ( duration = 0, lastUpdated = new Date(startTime), sparkUser = getSparkUser, - completed = false + completed = false, + appSparkVersion = appSparkVersion )) )) } @@ -139,6 +142,7 @@ private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) def appName: String = parent.appName + def appSparkVersion: String = parent.appSparkVersion } private[spark] object SparkUI { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 79b0d81af52b..8e1aafa448bc 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -228,7 +228,7 @@ private[spark] object UIUtils extends Logging {